线段树学习

最近的leetcode周赛中,经常出现线段树,加之今天每日一题有线段树,因此进行学习与训练

用途、操作、结构

1.用途:

  • 维护区间信息
  • 维护区间修改
  • 时间复杂度为logn
    直观的图形结构

维护线段树需要的操作

  • pushup:由子节点计算父节点的信息;
  • pushdown:把当前父节点的修改信息下传到子节点,也被称为懒标记(延迟标记);这个操作比较复杂,一般不涉及到区间修改则不用写。
  • build:将一段区间初始化成线段树;
  • modify:修改操作,分为两类:① 单点修改(需要使用pushup),② 区间修改(需要使用pushdown);
  • query:查询一段区间的值

对于不同的题目,具体查找的值也不尽相同,但是要求维护的值必须满足结合律,即同样的操作,即使顺序不同,得到的结果也是相同的。

V1. 区间和+区间修改+区间查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import List

class SegTree:
def __init__(self, siz, data):
self.tree = [0] * (siz*4+4)
self.lazy = [0] * (siz*4+4)
self.data = data

def build(self, l, r, p=1):
if l == r:
self.tree[p] = self.data[l-1]
return
mid = (l+r)//2
self.build(l, mid, p*2)
self.build(mid+1, r, p*2+1)
self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]

def push_down(self,p ,len):
self.lazy[p*2] += self.lazy[p]
self.lazy[p*2+1] += self.lazy[p] # lazy传递下去
self.tree[p*2] += self.lazy[p] * (len- len//2) # 但是值的更新需要取决于长度
self.tree[p*2+1] += self.lazy[p] * (len//2)
self.lazy[p] = 0 # 已经传递下去了,所以现在没有待传递的值

def query(self, l, r, ql, qr, p = 1):
if qr < l or ql > r: # 查询范围在区间外,返回0
return 0
if ql >= l and qr <= r: # 查询范围完全在区间内,返回内容
return self.tree[p]
else: #二分查找到目标区间
mid = (ql+qr)//2
self.push_down(p, qr-ql+1)
return self.query(l,r,ql,mid,p*2) \
+ self.query(l,r,mid+1,qr,p*2+1)

def update(self, l, r, ql, qr, data, p=1):
if qr < l or ql > r: # 查询范围在区间外,返回
return
if ql >= l and qr <= r: # 完全覆盖区间
self.tree[p] += (qr-ql+1) * data # 首先求值
if qr > ql: # 非叶子节点,向下传递
self.lazy[p] += data
else:#覆盖了一部分,递归继续更新
mid = (ql + qr) // 2
self.push_down(p,qr-ql+1)
self.update(l,r,ql,mid,data,p*2)
self.update(l,r,mid+1,qr,data,p*2+1)

self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]


class NumArray:

def __init__(self, nums: List[int]):
self.tree = SegTree(len(nums), nums)
self.tree.build(1,len(nums))
self.l = 1
self.r = len(nums)
self.nums = nums
print(self.tree.tree)


def update(self, index: int, val: int) -> None:
self.tree.update(index+1,index+1,self.l,self.r,val-self.nums[index])
self.nums[index] = val


def sumRange(self, left: int, right: int) -> int:
left += 1
right += 1
print(f"q {left}-{right}")
return self.tree.query(left,right,self.l, self.r,1)



V1.1 区间和+单点修改+区间查询

在原始代码中,我们已经在 update 方法的最后一行隐式地执行了 push_up 操作
将这一行代码抽象为一个单独的 push_up 方法并在需要的地方调用它,可以使代码更具可读性和可维护性。这使得代码结构更加清晰,便于理解和修改。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import List

class SegTree:
def __init__(self, siz, data):
self.tree = [0] * (siz*4+4)
self.lazy = [0] * (siz*4+4)
self.data = data

def build(self, l, r, p=1):
if l == r:
self.tree[p] = self.data[l-1]
return
mid = (l+r)//2
self.build(l, mid, p*2)
self.build(mid+1, r, p*2+1)
self.push_up(p)

def push_up(self, p):
self.tree[p] = self.tree[p * 2] + self.tree[p * 2 + 1]

def push_down(self, p, len):
self.lazy[p*2] += self.lazy[p]
self.lazy[p*2+1] += self.lazy[p] # lazy传递下去
self.tree[p*2] += self.lazy[p] * (len - len//2) # 但是值的更新需要取决于长度
self.tree[p*2+1] += self.lazy[p] * (len//2)
self.lazy[p] = 0 # 已经传递下去了,所以现在没有待传递的值

def query(self, l, r, ql, qr, p = 1):
if qr < l or ql > r: # 查询范围在区间外,返回0
return 0
if ql >= l and qr <= r: # 查询范围完全在区间内,返回内容
return self.tree[p]
else: #二分查找到目标区间
mid = (ql+qr)//2
self.push_down(p, qr-ql+1)
return self.query(l,r,ql,mid,p*2) \
+ self.query(l,r,mid+1,qr,p*2+1)

def update(self, l, r, ql, qr, data, p=1):
if qr < l or ql > r: # 查询范围在区间外,返回
return
if ql >= l and qr <= r: # 完全覆盖区间
self.tree[p] += (qr-ql+1) * data # 首先求值
if qr > ql: # 非叶子节点,向下传递
self.lazy[p] += data
else: #覆盖了一部分,递归继续更新
mid = (ql + qr) // 2
self.push_down(p,qr-ql+1)
self.update(l,r,ql,mid,data,p*2)
self.update(l,r,mid+1,qr,data,p*2+1)
self.push_up(p)

V2 区间最值线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List

class SegTree:
def __init__(self, siz, data):
self.tree = [0] * (siz*4+4)
self.lazy = [0] * (siz*4+4)
self.data = data

def build(self, l, r, p=1):
if l == r:
self.tree[p] = self.data[l-1]
return
mid = (l+r)//2
self.build(l, mid, p*2)
self.build(mid+1, r, p*2+1)
self.push_up(p)

def push_up(self, p):
self.tree[p] = max(self.tree[p * 2], self.tree[p * 2 + 1])

def push_down(self, p, len):
self.lazy[p*2] += self.lazy[p]
self.lazy[p*2+1] += self.lazy[p] # lazy传递下去
self.tree[p*2] += self.lazy[p]
self.tree[p*2+1] += self.lazy[p]
self.lazy[p] = 0 # 已经传递下去了,所以现在没有待传递的值

def query(self, l, r, ql, qr, p = 1):
if qr < l or ql > r: # 查询范围在区间外,返回负无穷大(用于比较最大值)
return float('-inf')
if ql >= l and qr <= r: # 查询范围完全在区间内,返回内容
return self.tree[p]
else: #二分查找到目标区间
mid = (ql+qr)//2
self.push_down(p, qr-ql+1)
return max(self.query(l,r,ql,mid,p*2), self.query(l,r,mid+1,qr,p*2+1))

def update(self, l, r, ql, qr, data, p=1):
if qr < l or ql > r: # 查询范围在区间外,返回
return
if ql >= l and qr <= r: # 完全覆盖区间
self.tree[p] = data # 更新为新的最大值
if qr > ql: # 非叶子节点,向下传递
self.lazy[p] += data - self.tree[p]
else: #覆盖了一部分,递归继续更新
mid = (ql + qr) // 2
self.push_down(p,qr-ql+1)
self.update(l,r,ql,mid,data,p*2)
self.update(l,r,mid+1,qr,data,p*2+1)
self.push_up(p)

v2.1 最值区间动态开点线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Node:
def __init__(self, l, r):
self.left = None
self.right = None
self.l = l
self.r = r
self.mid = (l + r) >> 1
self.v = 0
self.add = 0


class SegmentTree:
def __init__(self,l,r):
self.root = Node(l, int(r))

def modify(self, l, r, v, node=None):
if l > r:
return
if node is None:
node = self.root
if node.l >= l and node.r <= r:
node.v += v
node.add += v
return
self.pushdown(node)
if l <= node.mid:
self.modify(l, r, v, node.left)
if r > node.mid:
self.modify(l, r, v, node.right)
self.pushup(node)

def query(self, l, r, node=None):
if l > r:
return 0
if node is None:
node = self.root
if node.l >= l and node.r <= r:
return node.v
self.pushdown(node)
v = 0
if l <= node.mid:
v = max(v, self.query(l, r, node.left))
if r > node.mid:
v = max(v, self.query(l, r, node.right))
self.pushup(node)
return v

def pushup(self, node):
node.v = max(node.left.v, node.right.v)

def pushdown(self, node):
if node.left is None:
node.left = Node(node.l, node.mid)
if node.right is None:
node.right = Node(node.mid + 1, node.r)
if node.add:
node.left.v += node.add
node.right.v += node.add
node.left.add += node.add
node.right.add += node.add
node.add = 0


class MyCalendarThree:

def __init__(self):
self.seg = SegmentTree(1, pow(10,9)+1)

def book(self, startTime: int, endTime: int) -> int:
self.seg.modify(startTime+2, endTime+1, 1)
return self.seg.query(1, int(pow(10,9)+1))

v3 链表实现:动态开点+区间和+区间最大值+区间范围修改

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import random
class Node:
def __init__(self, l, r):
self.left = None
self.right = None
self.l = l
self.r = r
self.mid = (l + r) >> 1
self.v = 0
self.sum = 0
self.add = 0

class SegmentTree:
def __init__(self, l, r):
self.root = Node(l, int(r))

def modify(self, l, r, v, node=None):
if l > r:
return
if node is None:
node = self.root
if node.l >= l and node.r <= r:
node.v += v
node.add += v
node.sum += (node.r - node.l + 1) * v
return
self.pushdown(node)
if l <= node.mid:
self.modify(l, r, v, node.left)
if r > node.mid:
self.modify(l, r, v, node.right)
self.pushup(node)

def query(self, l, r, node=None, query_type="max"):
if l > r:
return 0 if query_type == "max" else 0
if node is None:
node = self.root
if node.l >= l and node.r <= r:
return node.v if query_type == "max" else node.sum
self.pushdown(node)

if query_type == "max":
v = float('-inf')
if l <= node.mid:
v = max(v, self.query(l, r, node.left, query_type))
if r > node.mid:
v = max(v, self.query(l, r, node.right, query_type))
else: # query_type == "sum"
v = 0
if l <= node.mid:
v += self.query(l, r, node.left, query_type)
if r > node.mid:
v += self.query(l, r, node.right, query_type)

self.pushup(node)
return v

def pushup(self, node):
node.v = max(node.left.v, node.right.v)
node.sum = node.left.sum + node.right.sum

def pushdown(self, node):
if node.left is None:
node.left = Node(node.l, node.mid)
if node.right is None:
node.right = Node(node.mid + 1, node.r)
if node.add:
node.left.v += node.add
node.right.v += node.add
node.left.sum += (node.left.r - node.left.l + 1) * node.add
node.right.sum += (node.right.r - node.right.l + 1) * node.add
node.left.add += node.add
node.right.add += node.add
node.add = 0

def test_segment_tree():
n = 10**2
arr = [0] * (n + 1)
tree = SegmentTree(1, n)

# Perform 100 random modify operations
for _ in range(100):
l = random.randint(1, n)
r = random.randint(l, n)
v = random.randint(-100, 100)

for i in range(l, r + 1):
arr[i] += v

tree.modify(l, r, v)

# Perform 100 random query operations for both max and sum
for _ in range(100):
l = random.randint(1, n)
r = random.randint(l, n)

max_val = max(arr[l:r + 1])
max_query = tree.query(l, r, query_type="max")
assert max_val == max_query, f"Expected max {max_val}, but got {max_query}"

sum_val = sum(arr[l:r + 1])
sum_query = tree.query(l, r, query_type="sum")
assert sum_val == sum_query, f"Expected sum {sum_val}, but got {sum_query}"

print("All tests passed!")

test_segment_tree()

v3.1 伪链表实现:动态开点+区间和+区间最大值+区间范围修改

v3.2 dict实现:动态开点+区间和+区间最大值+区间范围修改

1

题目

模板题307. 区域和检索 - 数组可修改

leetcode 699.掉落的方块


线段树学习
https://tech.jasonczc.cn/2023/algorithm/ds/interval_searching/segment_tree/
作者
CZY
发布于
2023年5月5日
许可协议