线段树Segmentation Tree

Posted by Meng Cao on 2019-06-01

定义

一棵balanced binary tree —> log(n)高度
给定start, end, 求[start,end]所有数的和/min/max
高效解决连续区间的动态查询问题,由于二叉结构的特性,它基本能保持每个操作的复杂度为O(logn)。
线段树的每个节点表示一个区间,子节点则分别表示父节点的左右半区间,例如父亲的区间是[a,b],那么(c=(a+b)/2)左儿子的区间是[a,c],右儿子的区间是[c+1,b]。

原理

每个leaf负责一个元素的值
每个parent负责的范围是他的children所负责的范围的union,并把所有范围内的元素值相加。
同一层的节点没有overlap。
root存储的是所有元素的和。
所以一个SegmentTreeNode需要记录以下信息

1
2
3
4
5
6
start #起始范围
end #终止范围
mid #拆分点,通常是 (start + end) // 2
val #所有子元素的和
left #左子树
right #右子树

C++ 版本的Segmentation类如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class SegmentTreeNode {
public:
SegmentTreeNode(int start, int end, int sum,
SegmentTreeNode* left = nullptr,
SegmentTreeNode* right = nullptr):
start(start),
end(end),
sum(sum),
left(left),
right(right){}//列表初始化
//SegmentTreeNode(const SegmentTreeNode&) = delete;
//SegmentTreeNode& operator=(const SegmentTreeNode&) = delete;
~SegmentTreeNode(){
delete left;
delete right;
left = right = nullptr;
}

int start;
int end;
int sum;
SegmentTreeNode* left;
SegmentTreeNode* right;
};

BulidTree

递归的构造
T(n) = 2*T(n/2) = O(n)

输入参数:start通常取0, end通常是元素的个数。

1
2
3
4
5
6
7
8
9
SegmentTreeNode* buildTree(int start, int end){
if(start==end)
return new SegmentTreeNode(start, end, nums_[start]);
int mid = start + (end - start) / 2;
auto left = buildTree(start, mid);
auto right = buildTree(mid+1, end);
auto node = new SegmentTreeNode(start, end, left->sum + right->sum, left, right);
return node;
}

Update Tree

线段树可以动态更新,但是元素个数不能变。

1
2
3
4
5
6
7
8
9
10
11
12
void updateTree(SegmentTreeNode* root, int i, int val){
if(root->start==i && root->end==i){
root->sum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
if(i <= mid) //要更新的节点在左子树
updateTree(root->left,i,val);
else //要更新的节点在右子树
updateTree(root->right,i,val);
root->sum = root->left->sum + root->right->sum;
}

Query

O(log n + k), k为访问到的节点的数量。

1
2
3
4
5
6
7
8
9
10
11
12
# 求index i ~ j 的元素和
int sumRange(SegmentTreeNode* root, int i, int j){
if(i==root->start && j==root->end)
return root->sum;
int mid = root->start + (root->end - root->start) / 2;
if(j<=mid)
return sumRange(root->left, i, j);
else if(i>mid)
return sumRange(root->right,i, j);
else
return sumRange(root->left,i,mid) + sumRange(root->right,mid+1,j);
}

习题

307. Range Sum Query - Mutable

308. Range Sum Query 2D - Mutable