线段树

本文最后更新于:2022年8月25日 下午

之前那个实在是垃圾到不能再垃圾了……今天必须重写一份。

这个写的会稍微快一点,因为只有一个数据结构,而且也没有大量的 $\LaTeX$,对我也是相当的友好啊。

如有学术问题,请毫不留情地在评论区予以指正。

Definition

线段树是一种基于二分思想的二叉树结构,用于在区间上进行信息统计。比起树状数组来用途更加广泛。

实现

关于线段树,你需要知道这些常识:

  1. 线段树的每个节点都代表一个区间。
  2. 线段树具有唯一的根节点,代表的区间是整个统计范围(即题目中让你维护的序列)。
  3. 线段树的每个叶子节点代表的区间是一个长度为 1 的区间。
  4. 对于每个内部节点 $[l, r]$,其左子节点为 $[l, mid]$,右子节点为 $[mid + 1, r]$,其中 $ mid = \left\lfloor\ (l+r)/2 \right\rfloor$。

我们可以发现,除了最后一层以外,线段树是一颗完全二叉树,所以我们可以用父子二倍的方法给其编号。树的深度为 $\mathcal{O}(\log N)$。

  1. 根节点为 1
  2. 编号为 $x$ 的左子节点为 $x * 2$,右节点为 $x * 2 + 1$。

如此一来,直接一个 struct 数组干上去,存储掉就可以了,多出的空间不必管它空着就行了,为了防止空间上的爆炸,请至少把保存线段树的数组开到 $N$ 的 4 倍(L_fire 奆佬说最好 8 倍或 16 倍,这个大家自己度德量力一下吧)

建树

要致富,先撸树。

我们建线段树就是为了维护序列,并且支持快速的查询与修改。给定一个长度为 $N$ 的序列 $A$,我们用每个叶子节点保存 $A$ 中的每个值。二叉树结构上下传递信息非常给力。下面让我们看看如何建树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
struct Tree {
int l, r; // 左端点,右端点
int dat; // 该区间维护的值
}t[800005];

void build(int p, int l, int r) { // p 代表当前的区间
t[p].l = l, t[p].r = r; // 把它的区间弄进来
if(l == r) { // 找到了叶子节点,赋值走人
t[p].dat = a[l];
return ;
}

int mid = (l + r) / 2;
build(p * 2, l, mid); // 左子树
build(p * 2 + 1, mid + 1, r); // 右子树
t[p].dat = max(t[p * 2].dat, t[p * 2 + 1].dat); // 更新答案
}

单点修改

我们知道,如果序列上随便改个点,那么这一堆区间的最值都得跟着遭殃,我们亟需快速的方法进行修改。

不过不用胆怯,线段树直接爆踩这些东西。

线段树中,根节点是任何操作的入口,我们从根节点一路递归到修改对应的叶子节点上,修改它的值,然后一路往上修改这个叶子节点的所有父亲节点,这样就完成了修改,时间复杂度为 $\mathcal{O}(\log N)$。

上代码:

1
2
3
4
5
6
7
8
9
10
11
12
void change(int p, int x, int v) { // p 是当前所在区间,x 是修改的位置,v 是修改的值 
if(t[p].l == t[p].r) { // 找到了 x 的叶子节点,我们直接修改
t[p].dat = v;
return ;
}
int mid = (t[p].l + t[p].r) / 2;

if(x <= mid) change(p * 2, x, v); // 如果修改的地方在左子树,去找那个节点
else change(p * 2 + 1, x, v); // 反之则去右子树找

t[p].dat = max(t[p * 2].dat, t[p * 2 + 1].dat); // 更新值
}

区间查询

这种问题一般问我们在一个区间中的最大值,我们利用线段树去解决此问题。

从根节点开始,我们执行以下递归过程:

  1. 如果要求区间完全覆盖掉了当前的区间,我们直接回溯并把此区间当做候选答案
  2. 若左子节点与要求区间有重合,我们递归访问左子节点
  3. 右子节点与要求区间有重合,我们递归访问右子节点

上代码:

1
2
3
4
5
6
7
8
9
10
int ask(int p, int l, int r) {
if(l <= t[p].l && t[p].r <= r) // 恰好覆盖
return t[p].dat;

int mid = (t[p].l + t[p].r) / 2;
int val = -0x3f3f3f3f;
if(l <= mid) val = max(val, ask(p * 2, l, r)); // 更新值
if(r > mid) val = max(val, ask(p * 2 + 1, l ,r)); // 更新值
return val;
}

这样的查询过程会将要求区间分成 $\log N$ 个节点,取最值作为答案。

$$现在,我们成功地写出了一个线段树!$$

区间修改

我们知道,很多题目都是要求我们进行区间修改的,这种时候如果沿用之前的修改方式时间复杂度会变成 $\mathcal{O}(n)$,这超出了我们的承受范围。

考虑一个问题,如果我们修改了该节点和其所有的子树节点,然而这些子树上的节点一个也没有用过,那我们修改它们就是徒劳的。

注意到这个性质之后,我们可以在该节点上打一个标记,修改该节点后并不对其子树上的节点开刀,等用到这个子节点的时候我们再向下给它更新值去,这个标记我们称其为延时标记。

换言之,延时标记的意义为:该节点被修改过,但其子节点未被修改。

在后续的操作中,如果需要子树上的节点的话,我们从 $p$ 开始递归,检查其是否带有标记,如果有的话就更新两个子节点,然后给子节点打上标记,最后清除 $p$ 的标记。

我们可以发现,每条查询与修改皆变为了 $\mathcal{O}(\log n)$。

好的,来几个例题吧!

例题

https://www.acwing.com/problem/content/description/244/

这是个线段树的裸题,我们直接上代码,我相信你们能看懂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <cstdio>
#include <algorithm>
#include <string>
#include <iostream>
#define ll long long
#define l(x) t[x].l
#define r(x) t[x].r
#define sum(x) t[x].sum
#define add(x) t[x].add
using namespace std;

struct Node {
int l, r;
ll sum, add;
}t[500005];

int a[100005], n, m;

void build(int p, int l, int r) {
l(p) = l, r(p) = r;
if(l == r) {
sum(p) = a[l];
return ;
}

int mid = (l + r) / 2;
build(p * 2, l, mid);
build(p * 2 + 1, mid + 1, r);
sum(p) = sum(p * 2) + sum(p * 2 + 1);
}

void spr(int p) {
if(add(p)) {
sum(p * 2) += add(p) * (r(p * 2) - l(p * 2) + 1);
sum(p * 2 + 1) += add(p) * (r(p * 2 + 1) - l(p * 2 + 1) + 1);
add(p * 2) += add(p);
add(p * 2 + 1) += add(p);
add(p) = 0;
}
}

void change(int p, int l, int r, int d) {
if(l <= l(p) && r >= r(p)) {
sum(p) += (ll)d * (r(p) - l(p) + 1);
add(p) += d;
return ;
}

spr(p);
int mid = (l(p) + r(p)) / 2;
if(l <= mid) change(p * 2, l, r, d);
if(r > mid) change(p * 2 + 1, l, r, d);
sum(p) = sum(p * 2) + sum(p * 2 + 1);
}

ll ask(int p, int l, int r) {
if(l <= l(p) && r >= r(p)) return sum(p);
spr(p);

int mid = (l(p) + r(p)) / 2;
ll val = 0;
if(l <= mid) val += ask(p * 2, l, r);
if(r > mid) val += ask(p * 2 + 1, l, r);

return val;
}

int main() {
scanf("%d%d", &n, &m);
for(int i=1; i<=n; i++)
scanf("%d", &a[i]);

build(1, 1, n);
while(m--) {
string s;
int l, r, d;
cin >> s >> l >> r;
if(s == "C") {
scanf("%d", &d);
change(1, l, r, d);
} else printf("%lld\n", ask(1, l, r));
}
return 0;
}

https://www.luogu.com.cn/problem/P3372

跟上边这题一样,自己改改去就行了。

https://www.acwing.com/problem/content/246/

这道题是个单点修改的,不过要注意最大连续子段和的求法:

这里我们除了 $l, r$ 以外再维护 4 个信息:

区间和 $\operatorname{sum}$,区间最大连续子段和 $\operatorname{dat}$,紧靠左端的最大连续子段和 $\operatorname{lmax}$,紧靠右端的最大连续子段和 $\operatorname{rmax}$。

如何求 $\operatorname{dat}$ 呢?首先我们需要两个子节点的区间和来更新此节点的区间和,随后更新紧靠左端与紧靠右端的最大连续子段和,更新方式较为类似,这里介绍一种方法另一种大家自己依葫芦画瓢就行了:

更新紧靠左端的最大连续子段和:从左子树的紧靠左端的最大连续子段和与左子树的紧靠右端的最大连续子段和加上右子树的区间和中取最大值。

最后,我们更新好了这两个东西之后我们就用左右子树的最大连续子段和与左子树的紧靠右端的最大连续子段和加上右子树紧靠左端的最大连续子段和这三个数中取最大值,就是我们要的最大连续子段和了。

贴代码(这份代码是我和 L_fire 大佬一起写的,所以码风不大一样,不过大部分是我自己编写的):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <cstdio>
#include <algorithm>
using namespace std;

struct Node {
int sum, lmax, rmax, dat;
int l, r;
}t[6000005];

int n, m,inf=-0x3f3f3f3f;
int a[500005];
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if(l == r) {
t[p].sum = a[l];
t[p].lmax = a[l];
t[p].rmax = a[l];
t[p].dat = a[l];
return ;
}

int mid = (l + r) / 2;
build(p * 2, l, mid);
build(p * 2 + 1, mid + 1, r);
t[p].sum = t[p * 2].sum + t[p * 2 + 1].sum;
t[p].lmax = max(t[p * 2].lmax, t[p * 2].sum + t[p * 2 + 1].lmax);
t[p].rmax = max(t[p * 2 + 1].rmax, t[p * 2 + 1].sum + t[p * 2].rmax);
t[p].dat = max(t[p * 2].dat, max(t[p * 2 + 1].dat, t[p * 2].rmax + t[p * 2 + 1].lmax));
}

void change(int p, int x, int y) {
if(t[p].l == t[p].r) {
t[p].sum = y;
t[p].lmax = y;
t[p].rmax = y;
t[p].dat = y;
return ;
}

int mid = (t[p].l + t[p].r) / 2;
if(x <= mid) change(p * 2, x, y);
else change(p * 2 + 1, x, y);
t[p].sum = t[p * 2].sum + t[p * 2 + 1].sum;
t[p].lmax = max(t[p * 2].lmax, t[p * 2].sum + t[p * 2 + 1].lmax);
t[p].rmax = max(t[p * 2 + 1].rmax, t[p * 2 + 1].sum + t[p * 2].rmax);
t[p].dat = max(t[p * 2].dat, max(t[p * 2 + 1].dat, t[p * 2].rmax + t[p * 2 + 1].lmax));
}

Node ask(int p, int l, int r) {
if(l<=t[p].l&&r>=t[p].r) return t[p];
int mid=(t[p].l+t[p].r)>>1;
if(r <= mid) return ask(p * 2, l, r);
else if(l > mid) return ask(p * 2 + 1, l, r);
else {
Node t1 = ask(p * 2, l, mid), t2 = ask(p * 2 + 1, mid + 1, r), t3;
t3.l = l, t3.r = r;
t3.sum = t1.sum + t2.sum;
t3.lmax = max(t1.lmax, t1.sum + t2.lmax);
t3.rmax = max(t2.rmax, t2.sum + t1.rmax);
t3.dat = max(t1.dat, max(t2.dat, t1.rmax + t2.lmax));
return t3;
}

}

int main() {
scanf("%d%d", &n, &m);
for(int i=1; i<=n; i++)
scanf("%d", &a[i]);

build(1, 1, n);
while(m--) {
int k, x, y;
scanf("%d%d%d", &k, &x, &y);
if(k == 1) {
if(x > y) swap(x, y);
printf("%d\n", ask(1, x, y).dat);
} else
change(1, x, y);

}
return 0;
}

https://www.luogu.com.cn/problem/P3373

我们发现这个题要维护两种标记,乘法与加法,也就是说我们需要两个传递函数。

不过我们需要注意到一个东西:这里更新乘法标记的时候也要更新加法标记,但是更新加法标记的时候不能更新乘法标记。

我觉得都看了这么多代码了,自己写一次应该可以吧。

https://www.luogu.com.cn/problem/P4588

这道题第一眼看上去根本不知道在干些什么……分析分析看。

我们看的出来,现在的问题是如何快速地找到第 $pos$ 次操作所乘的数,那么明显我们要维护所有的 $m$。

不妨转换思想,用离线的方式来考虑问题(好吧其实是半离线),当遇上一个要求乘法操作的时候我们就将其存入一个数组中,然后拿这个数组来建树,(并记录下这个数对应的次数),每次出现一个除法操作我们就去二分查找次数从而确定它的数值,然后将这个数值修改成 1(想想看,为什么?)。

为什么说是半离线操作呢?因为我们发现每次的输出都是有要求的,即我们需要以当前的元素个数进行乘法处理,乘完之后我们直接输出根节点的元素解决问题,因此掌控好合适的区间是必要的,这就是我们需要一个变量来控制区间的原因。

行了行了上代码(这份也是我和 L_fire 一起编写的,码风差异请自行忽略):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <cstdio>
#include <algorithm>
#include <iostream>
#define ll long long
#define l(x) t[x].l
#define r(x) t[x].r
#define sum(x) t[x].sum
using namespace std;

struct Node {
int l, r;
ll add, sum;
}t[800005];

struct Node1{
int op,m;
}c[1000005];
ll mod;
int a[100005], b[100005];

void build(int p, int l, int r) {
l(p) = l, r(p) = r;
if(l == r) {
sum(p) = a[l];
return ;
}

int mid = (l + r) / 2;
build(p * 2, l ,mid);
build(p * 2 + 1, mid + 1, r);

sum(p) = sum(p * 2) * sum(p * 2 + 1) % mod;
}

void change(int p, int l, int r) {
if(l(p) == r(p)) {
sum(p) = 1;
return ;
}

int mid = (l(p) + r(p)) / 2;

if(l <= mid) change(p * 2, l, r);
if(r > mid) change(p * 2 + 1, l, r);

sum(p) = sum(p * 2) * sum(p * 2 + 1) % mod;
}

int ask(int p, int l, int r) {
if(l <= l(p) && r >= r(p))
return sum(p);

int mid = (l(p) + r(p)) / 2;
ll val = 1;
if(r > mid) val = val * ask(p * 2 + 1, l, r) % mod;
if(l <= mid) val = val * ask(p * 2, l, r) % mod;

return val % mod;
}

int q, T;

int main() {
scanf("%d", &T);
while(T--) {
cin>>q>>mod;
int cnt=0;
for(int i=1;i<=q;i++)
{
cin>>c[i].op>>c[i].m;
if(c[i].op==1) {
a[++cnt] = c[i].m;
b[cnt] = i;
}
}
build(1, 1, cnt);
int top = 0;
for(int i=1; i<=q; i++)
{
if(c[i].op == 1) top++;
else {
int num = lower_bound(b + 1, b + 1 + cnt, c[i].m) - b;
change(1, num, num);
}
cout << ask(1, 1, top) << endl;
}
}

return 0;
}

线段树
http://dxrprime.github.io/2022/08/25/线段树/
作者
Shelter Prime
发布于
2022年8月25日
许可协议