Summary of the Segment Tree Data Structure
Suppose we have n points indexed from 1 to n, each storing a numeric value. We often need to update, or compute an aggregate over, a contiguous subrange [L, R]. A segment tree supports both the update and the query in O(log n) time.
The core idea of a segment tree is to recursively split the entire range [1, n] into two halves until single points remain. Stored in an array where node i has its children at indices 2i and 2i+1, the tree needs no more than 4n nodes. Any range [L, R] can be decomposed into O(log n) of these pre-built subintervals, so the final result is obtained by combining the results stored for those subintervals.
For this decomposition approach to work, the aggregated value we query must satisfy the interval combination property: the result of a parent interval can be computed directly from the results of its left and right child intervals. Common examples that meet this requirement:
- Sum: Total sum of the interval = sum of left interval + sum of right interval
- Maximum value: Total maximum of the interval = max(left interval maximum, right interval maximum)
Implementation: Segment Tree with Range Add, Range Multiply, and Range Sum Query
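Below is a step-by-step implementation with lazy propagation to handle both range update types efficiently. Each node's pending update has the form x -> x * mul_lazy + add_lazy, i.e. the multiplication is applied before the addition. When propagating a parent's pending update to a child, the child's sum becomes new_child_sum = current_child_sum * parent_mul_lazy + parent_add_lazy * child_length, and the child's own tags are composed the same way: child_mul_lazy = child_mul_lazy * parent_mul_lazy and child_add_lazy = child_add_lazy * parent_mul_lazy + parent_add_lazy. All arithmetic is performed modulo MOD.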
Build the Tree
// Recursively builds the node `node_idx` so that it covers [left, right].
//
// Every visited node records its interval and starts with identity lazy tags
// (multiply by 1, add 0). A leaf takes its value from input_arr[left], reduced
// into [0, MOD); an internal node splits at the midpoint, builds its children
// at 2 * node_idx and 2 * node_idx + 1, and then sums them via push_up.
//
// An empty interval (left > right, e.g. build(1, 0, 1) when there are no
// elements) is ignored: without this guard the recursion never reaches
// left == right and does not terminate.
void build(int left, int right, int node_idx) {
    if (left > right) return;

    tree[node_idx].range_left = left;
    tree[node_idx].range_right = right;
    tree[node_idx].mul_lazy = 1;
    tree[node_idx].add_lazy = 0;

    if (left == right) {
        // C++ '%' keeps the sign of the dividend, so a negative input would
        // leave a negative residue; shift it up into [0, MOD).
        long long leaf = input_arr[left] % MOD;
        if (leaf < 0) leaf += MOD;
        tree[node_idx].sum = leaf;
        return;
    }

    // Same midpoint formula as push_down, which derives child lengths from it.
    int mid = (left + right) / 2;
    build(left, mid, 2 * node_idx);
    build(mid + 1, right, 2 * node_idx + 1);
    push_up(node_idx);
}
Push Up Operation
Combine results from child nodes to update the current node's value:
// Recomputes the sum stored at node_idx from its two children
// (2 * node_idx and 2 * node_idx + 1), reduced modulo MOD.
void push_up(int node_idx) {
    tree[node_idx].sum =
        (tree[2 * node_idx].sum + tree[2 * node_idx + 1].sum) % MOD;
}
Push Down Operation
Propagate pending lazy updates to child nodes:
// Pushes the pending lazy tags of node_idx one level down, then clears them.
//
// A pending pair (mul, add) stands for the map x -> x * mul + add applied to
// every element of the node's interval. For each child this function
//   - applies the map to the child's sum (the add part counts once per
//     element of the child's interval), and
//   - composes the map onto the child's own pending tags.
// Both child slots are accessed unconditionally, so the node is expected to be
// an internal node with index >= 1 (children distinct from the parent).
void push_down(int node_idx) {
    const int lo = tree[node_idx].range_left;
    const int hi = tree[node_idx].range_right;
    const int split = (lo + hi) / 2;

    const long long mul = tree[node_idx].mul_lazy;
    const long long add = tree[node_idx].add_lazy;

    // Slot 0: left child covering [lo, split]; slot 1: right child covering [split + 1, hi].
    const int child[2] = {2 * node_idx, 2 * node_idx + 1};
    const int length[2] = {split - lo + 1, hi - split};

    for (int side = 0; side < 2; ++side) {
        const int c = child[side];
        tree[c].sum = (tree[c].sum * mul % MOD + add * length[side]) % MOD;
        tree[c].mul_lazy = (tree[c].mul_lazy * mul) % MOD;
        tree[c].add_lazy = (tree[c].add_lazy * mul + add) % MOD;
    }

    // The parent no longer owes anything to its subtree: back to the identity map.
    tree[node_idx].add_lazy = 0;
    tree[node_idx].mul_lazy = 1;
}
Range Addition Update
// Adds `val` to every element in [query_l, query_r] that lies inside the
// subtree rooted at node_idx.
//
// `val` is first reduced into [0, MOD). Without that, a negative value would
// leave negative sums/tags (C++ '%' keeps the dividend's sign) and a very
// large value could overflow `val * length`. The reduction is idempotent, so
// repeating it in the recursive calls is harmless.
//
// A node fully inside the query is updated in place and the change is recorded
// in its add tag; otherwise pending tags are pushed down, both children are
// visited, and the node's sum is rebuilt from them.
void range_add(int node_idx, int query_l, int query_r, long long val) {
    val %= MOD;
    if (val < 0) val += MOD;

    int curr_l = tree[node_idx].range_left;
    int curr_r = tree[node_idx].range_right;

    // No overlap with the query: nothing to do in this subtree.
    if (curr_l > query_r || curr_r < query_l) return;

    // Fully covered: every one of the (curr_r - curr_l + 1) elements gains val.
    if (curr_l >= query_l && curr_r <= query_r) {
        tree[node_idx].add_lazy = (tree[node_idx].add_lazy + val) % MOD;
        tree[node_idx].sum = (tree[node_idx].sum + val * (curr_r - curr_l + 1)) % MOD;
        return;
    }

    // Partial overlap: children must be up to date before they are modified.
    push_down(node_idx);
    range_add(2 * node_idx, query_l, query_r, val);
    range_add(2 * node_idx + 1, query_l, query_r, val);
    push_up(node_idx);
}
Range Multiplication Update
// Multiplies every element in [query_l, query_r] that lies inside the subtree
// rooted at node_idx by `val`.
//
// `val` is first reduced into [0, MOD). Without that, a negative factor would
// leave negative sums/tags (C++ '%' keeps the dividend's sign) and a very
// large factor could overflow the 64-bit products below. The reduction is
// idempotent, so repeating it in the recursive calls is harmless.
//
// Scaling a fully covered node multiplies its sum and BOTH pending tags,
// because (x * mul + add) * val == x * (mul * val) + add * val.
void range_mul(int node_idx, int query_l, int query_r, long long val) {
    val %= MOD;
    if (val < 0) val += MOD;

    int curr_l = tree[node_idx].range_left;
    int curr_r = tree[node_idx].range_right;

    // No overlap with the query: nothing to do in this subtree.
    if (curr_l > query_r || curr_r < query_l) return;

    if (curr_l >= query_l && curr_r <= query_r) {
        tree[node_idx].mul_lazy = (tree[node_idx].mul_lazy * val) % MOD;
        tree[node_idx].add_lazy = (tree[node_idx].add_lazy * val) % MOD;
        tree[node_idx].sum = (tree[node_idx].sum * val) % MOD;
        return;
    }

    // Partial overlap: children must be up to date before they are modified.
    push_down(node_idx);
    range_mul(2 * node_idx, query_l, query_r, val);
    range_mul(2 * node_idx + 1, query_l, query_r, val);
    push_up(node_idx);
}
Range Sum Query
// Returns the sum, modulo MOD, of the elements of [query_l, query_r] that fall
// inside the subtree rooted at node_idx.
long long query_sum(int node_idx, int query_l, int query_r) {
    const int lo = tree[node_idx].range_left;
    const int hi = tree[node_idx].range_right;

    const bool disjoint = lo > query_r || hi < query_l;
    if (disjoint) {
        return 0;
    }

    const bool covered = lo >= query_l && hi <= query_r;
    if (covered) {
        return tree[node_idx].sum;
    }

    // Children may lag behind this node's pending tags; bring them up to date
    // before reading their sums.
    push_down(node_idx);
    const long long left_part = query_sum(2 * node_idx, query_l, query_r);
    const long long right_part = query_sum(2 * node_idx + 1, query_l, query_r);
    return (left_part + right_part) % MOD;
}
Full Working Template
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;

// Capacity of the 1-indexed input array; the tree is given 4x as many slots.
const int MAX_N = 100020;

// Modulus for every stored sum and tag; read from the input in main.
ll MOD;

// One node of the segment tree.
struct SegNode {
    ll sum;           // sum of the node's interval, modulo MOD
    ll add_lazy;      // pending addition not yet pushed to the children
    ll mul_lazy;      // pending multiplication not yet pushed to the children
    int range_left;   // left end of the interval covered by this node
    int range_right;  // right end of the interval covered by this node
};

// Array-backed tree: node i has its children at 2 * i and 2 * i + 1.
SegNode tree[MAX_N * 4];

// Initial values, filled for indices 1..n.
ll input_arr[MAX_N];

// Number of elements and number of operations.
int n;
int m;
// Refreshes the sum of node_idx from its children at 2 * node_idx and
// 2 * node_idx + 1, keeping the result reduced modulo MOD.
void push_up(int node_idx) {
    const long long children_total =
        tree[2 * node_idx].sum + tree[2 * node_idx + 1].sum;
    tree[node_idx].sum = children_total % MOD;
}
// Recursively builds the node `node_idx` so that it covers [left, right].
//
// Every visited node records its interval and starts with identity lazy tags
// (multiply by 1, add 0). A leaf takes its value from input_arr[left], reduced
// into [0, MOD); an internal node splits at the midpoint, builds its children
// at 2 * node_idx and 2 * node_idx + 1, and then sums them via push_up.
//
// An empty interval (left > right, e.g. build(1, 0, 1) when n == 0) is
// ignored: without this guard the recursion never reaches left == right and
// does not terminate.
void build(int left, int right, int node_idx) {
    if (left > right) return;

    tree[node_idx].range_left = left;
    tree[node_idx].range_right = right;
    tree[node_idx].mul_lazy = 1;
    tree[node_idx].add_lazy = 0;

    if (left == right) {
        // C++ '%' keeps the sign of the dividend, so a negative input would
        // leave a negative residue; shift it up into [0, MOD).
        long long leaf = input_arr[left] % MOD;
        if (leaf < 0) leaf += MOD;
        tree[node_idx].sum = leaf;
        return;
    }

    // Same midpoint formula as push_down, which derives child lengths from it.
    int mid = (left + right) / 2;
    build(left, mid, 2 * node_idx);
    build(mid + 1, right, 2 * node_idx + 1);
    push_up(node_idx);
}
// Hands the pending lazy tags of node_idx to its two children and resets them.
//
// The parent's tags (mul, add) represent x -> x * mul + add for every element
// below it. Each child has that map applied to its sum (the add part once per
// element in the child's interval) and composed onto its own tags.
// Both child slots are accessed unconditionally, so the node is expected to be
// an internal node with index >= 1 (children distinct from the parent).
void push_down(int node_idx) {
    const int lo = tree[node_idx].range_left;
    const int hi = tree[node_idx].range_right;
    const int split = (lo + hi) / 2;

    const long long mul = tree[node_idx].mul_lazy;
    const long long add = tree[node_idx].add_lazy;

    // Slot 0: left child covering [lo, split]; slot 1: right child covering [split + 1, hi].
    const int child[2] = {2 * node_idx, 2 * node_idx + 1};
    const int length[2] = {split - lo + 1, hi - split};

    for (int side = 0; side < 2; ++side) {
        const int c = child[side];
        tree[c].sum = (tree[c].sum * mul % MOD + add * length[side]) % MOD;
        tree[c].mul_lazy = (tree[c].mul_lazy * mul) % MOD;
        tree[c].add_lazy = (tree[c].add_lazy * mul + add) % MOD;
    }

    // Nothing is pending at the parent any more: restore the identity map.
    tree[node_idx].add_lazy = 0;
    tree[node_idx].mul_lazy = 1;
}
// Adds `val` to every element in [l, r] that lies inside the subtree rooted
// at node_idx.
//
// `val` is first reduced into [0, MOD). Without that, a negative value would
// leave negative sums/tags (C++ '%' keeps the dividend's sign) and a very
// large value could overflow `val * length`. The reduction is idempotent, so
// repeating it in the recursive calls is harmless.
//
// A node fully inside the query is updated in place and the change is recorded
// in its add tag; otherwise pending tags are pushed down, both children are
// visited, and the node's sum is rebuilt from them.
void range_add(int node_idx, int l, int r, ll val) {
    val %= MOD;
    if (val < 0) val += MOD;

    int curr_l = tree[node_idx].range_left;
    int curr_r = tree[node_idx].range_right;

    // No overlap with the query: nothing to do in this subtree.
    if (curr_l > r || curr_r < l) return;

    // Fully covered: every one of the (curr_r - curr_l + 1) elements gains val.
    if (curr_l >= l && curr_r <= r) {
        tree[node_idx].add_lazy = (tree[node_idx].add_lazy + val) % MOD;
        tree[node_idx].sum = (tree[node_idx].sum + val * (curr_r - curr_l + 1)) % MOD;
        return;
    }

    // Partial overlap: children must be up to date before they are modified.
    push_down(node_idx);
    range_add(2 * node_idx, l, r, val);
    range_add(2 * node_idx + 1, l, r, val);
    push_up(node_idx);
}
// Multiplies every element in [l, r] that lies inside the subtree rooted at
// node_idx by `val`.
//
// `val` is first reduced into [0, MOD). Without that, a negative factor would
// leave negative sums/tags (C++ '%' keeps the dividend's sign) and a very
// large factor could overflow the 64-bit products below. The reduction is
// idempotent, so repeating it in the recursive calls is harmless.
//
// Scaling a fully covered node multiplies its sum and BOTH pending tags,
// because (x * mul + add) * val == x * (mul * val) + add * val.
void range_mul(int node_idx, int l, int r, ll val) {
    val %= MOD;
    if (val < 0) val += MOD;

    int curr_l = tree[node_idx].range_left;
    int curr_r = tree[node_idx].range_right;

    // No overlap with the query: nothing to do in this subtree.
    if (curr_l > r || curr_r < l) return;

    if (curr_l >= l && curr_r <= r) {
        tree[node_idx].mul_lazy = (tree[node_idx].mul_lazy * val) % MOD;
        tree[node_idx].add_lazy = (tree[node_idx].add_lazy * val) % MOD;
        tree[node_idx].sum = (tree[node_idx].sum * val) % MOD;
        return;
    }

    // Partial overlap: children must be up to date before they are modified.
    push_down(node_idx);
    range_mul(2 * node_idx, l, r, val);
    range_mul(2 * node_idx + 1, l, r, val);
    push_up(node_idx);
}
// Returns the sum, modulo MOD, of the elements of [l, r] that fall inside the
// subtree rooted at node_idx.
ll query_sum(int node_idx, int l, int r) {
    const int lo = tree[node_idx].range_left;
    const int hi = tree[node_idx].range_right;

    const bool disjoint = lo > r || hi < l;
    if (disjoint) {
        return 0;
    }

    const bool covered = lo >= l && hi <= r;
    if (covered) {
        return tree[node_idx].sum;
    }

    // Children may lag behind this node's pending tags; bring them up to date
    // before reading their sums.
    push_down(node_idx);
    const ll left_part = query_sum(2 * node_idx, l, r);
    const ll right_part = query_sum(2 * node_idx + 1, l, r);
    return (left_part + right_part) % MOD;
}
int main() {
scanf("%d%d%lld", &n, &m, &MOD);
for (int i = 1; i <= n; i++) {
scanf("%lld", &input_arr[i]);
}
build(1, n, 1);
for (int i = 1; i <= m; i++) {
int op, x, y;
ll k;
scanf("%d", &op);
switch(op) {
case 1:
scanf("%d%d%lld", &x, &y, &k);
range_mul(1, x, y, k);
break;
case 2:
scanf("%d%d%lld", &x, &y, &k);
range_add(1, x, y, k);
break;
case 3:
scanf("%d%d", &x, &y);
printf("%lld\n", query_sum(1, x, y));
break;
}
}
return 0;
}