forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsegment_tree.py
More file actions
177 lines (146 loc) · 5.1 KB
/
segment_tree.py
File metadata and controls
177 lines (146 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Segment Tree Data Structure.

A segment tree is a binary tree built over an array in which each node stores
an aggregate (for example the sum, minimum or maximum) of a contiguous range of
array elements. It answers range queries and applies point updates in
logarithmic time.

Time Complexity:
- Build: O(n)
- Query: O(log n)
- Update: O(log n)

Space Complexity: O(n)
"""
from typing import Callable
class SegmentTree:
    """Segment tree over a fixed-length array for range queries and point updates.

    Every internal node stores ``operation`` applied to its two children.
    Addition is the default; ``min`` and ``max`` also work. ``operation``
    should be associative so that combining children in any grouping gives the
    same range result.

    Attributes:
        tree: Flat list of tree nodes; the children of node ``i`` live at
            indices ``2 * i + 1`` and ``2 * i + 2``.
        n: Size of the input array.
        operation: Function used to combine two values (default: addition).

    >>> st = SegmentTree([1, 3, 5, 7, 9, 11])
    >>> st.query(1, 3)
    15
    >>> st.update(1, 10)
    >>> st.query(1, 3)
    22
    >>> st.query(0, 5)
    43
    >>> st2 = SegmentTree([2, 4, 6, 8], operation=min)
    >>> st2.query(0, 3)
    2
    >>> st2.update(0, 10)
    >>> st2.query(0, 3)
    4
    >>> st2.query(1, 2)
    4
    """

    def __init__(
        self, arr: list[int], operation: Callable[[int, int], int] = lambda a, b: a + b
    ) -> None:
        """Initialize the segment tree from ``arr``.

        Args:
            arr: Input array of integers (may be empty).
            operation: Binary operation used to combine values
                (default: addition).

        >>> st = SegmentTree([1, 2, 3])
        >>> len(st.tree)
        12
        >>> SegmentTree([]).query(0, 0)
        0
        """
        self.n = len(arr)
        # 4 * n slots are enough for this recursive midpoint layout.
        self.tree = [0] * (4 * self.n)
        self.operation = operation
        # An empty array has no root segment. Building on [0, -1] would never
        # reach a leaf and would recurse until RecursionError.
        if self.n > 0:
            self._build(arr, 0, 0, self.n - 1)

    def _build(self, arr: list[int], node: int, start: int, end: int) -> None:
        """Build the subtree rooted at ``node`` covering ``arr[start..end]``.

        Args:
            arr: Input array.
            node: Current node index in ``tree``.
            start: Start index of the current segment (inclusive).
            end: End index of the current segment (inclusive).
        """
        if start == end:
            # Leaf node holds the array element itself.
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        self._build(arr, left_child, start, mid)
        self._build(arr, right_child, mid + 1, end)
        self.tree[node] = self.operation(self.tree[left_child], self.tree[right_child])

    def query(self, left: int, right: int) -> int:
        """Combine the values in the inclusive index range [left, right].

        Bounds outside ``[0, n - 1]`` are clamped to the array. If the clamped
        range is empty (for example ``left > right``), there is no element to
        combine. In that case the method returns the legacy placeholder value:
        ``0`` when ``operation(0, 0) == 0`` (as for addition), otherwise
        ``float("inf")``.

        Args:
            left: Left boundary of the query range (inclusive).
            right: Right boundary of the query range (inclusive).

        Returns:
            Result of applying ``operation`` over the range.

        >>> st = SegmentTree([1, 2, 3, 4, 5])
        >>> st.query(0, 2)
        6
        >>> st.query(2, 4)
        12
        >>> st.query(0, -1)
        0
        >>> SegmentTree([-5, -3, -1], operation=max).query(0, 1)
        -3
        """
        left = max(left, 0)
        right = min(right, self.n - 1)
        if left > right:
            # Empty range: keep the value the original implementation returned.
            return 0 if self.operation(0, 0) == 0 else float("inf")
        return self._query(0, 0, self.n - 1, left, right)

    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
        """Recursive helper for a range query.

        Precondition: ``[left, right]`` is non-empty and overlaps
        ``[start, end]``. Only children that overlap the query range are
        visited, so no identity element for ``operation`` is ever needed. This
        keeps ``min``, ``max`` and similar operations correct.

        Args:
            node: Current node index.
            start: Start of the current segment (inclusive).
            end: End of the current segment (inclusive).
            left: Query left boundary (inclusive).
            right: Query right boundary (inclusive).

        Returns:
            The combined value of ``[left, right]`` within the current segment.
        """
        if left <= start and end <= right:
            # The node's segment lies entirely inside the query range.
            return self.tree[node]
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        if right <= mid:
            # The query touches only the left half.
            return self._query(left_child, start, mid, left, right)
        if left > mid:
            # The query touches only the right half.
            return self._query(right_child, mid + 1, end, left, right)
        return self.operation(
            self._query(left_child, start, mid, left, right),
            self._query(right_child, mid + 1, end, left, right),
        )

    def update(self, index: int, value: int) -> None:
        """Set the element at ``index`` to ``value`` and refresh its ancestors.

        Args:
            index: Index to update; must satisfy ``0 <= index < n``.
            value: New value.

        Raises:
            IndexError: If ``index`` is outside ``[0, n - 1]``.

        >>> st = SegmentTree([1, 2, 3, 4, 5])
        >>> st.query(0, 4)
        15
        >>> st.update(2, 10)
        >>> st.query(0, 4)
        22
        >>> st.update(5, 1)
        Traceback (most recent call last):
            ...
        IndexError: index 5 out of range for SegmentTree of size 5
        """
        if not 0 <= index < self.n:
            # Without this check an out-of-range index silently overwrote the
            # first (index < 0) or last (index >= n) element.
            raise IndexError(
                f"index {index} out of range for SegmentTree of size {self.n}"
            )
        self._update(0, 0, self.n - 1, index, value)

    def _update(self, node: int, start: int, end: int, index: int, value: int) -> None:
        """Recursive helper for a point update.

        Args:
            node: Current node index.
            start: Start of the current segment (inclusive).
            end: End of the current segment (inclusive).
            index: Index to update (assumed to lie within ``[start, end]``).
            value: New value.
        """
        if start == end:
            # Leaf node: store the new value directly.
            self.tree[node] = value
            return
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        if index <= mid:
            self._update(left_child, start, mid, index, value)
        else:
            self._update(right_child, mid + 1, end, index, value)
        # Recompute this node from its (possibly changed) children.
        self.tree[node] = self.operation(self.tree[left_child], self.tree[right_child])
if __name__ == "__main__":
    # When run as a script, execute the doctest examples in this module.
    from doctest import testmod

    testmod()