线段树
算法概述
【线段树】(Segment Tree)是一种用于处理区间查询和修改操作的数据结构。它能在O(log n)的时间复杂度内实现单点更新和区间查询,如求和、求最大值、求最小值等操作。
算法设计思路
线段树的核心思想是:
- 将原始数组的区间划分为多个子区间
- 使用平衡二叉树表示这些区间,每个节点代表一个区间
- 父节点的区间是其左右子节点区间的并集
- 通过递归方式构建树结构和执行查询/更新操作
代码实现与解析
线段树的基本结构
// 线段树节点结构
struct SegmentTreeNode {
int start, end; // 区间范围
int sum; // 区间和(或其他聚合信息,如最大值、最小值)
SegmentTreeNode* left;
SegmentTreeNode* right;
SegmentTreeNode(int start, int end) {
this->start = start;
this->end = end;
this->sum = 0;
this->left = nullptr;
this->right = nullptr;
}
};
线段树的构建
// 构建线段树
SegmentTreeNode* buildTree(vector<int>& nums, int start, int end) {
if (start > end) return nullptr;
SegmentTreeNode* root = new SegmentTreeNode(start, end);
// 叶子节点
if (start == end) {
root->sum = nums[start];
return root;
}
// 非叶子节点,递归构建左右子树
int mid = start + (end - start) / 2;
root->left = buildTree(nums, start, mid);
root->right = buildTree(nums, mid + 1, end);
// 计算当前节点的区间和(根据需求可以是其他聚合操作)
root->sum = root->left->sum + root->right->sum;
return root;
}
区间查询
// 查询区间[queryStart, queryEnd]的和
int querySum(SegmentTreeNode* root, int queryStart, int queryEnd) {
// 边界情况
if (!root || queryStart > root->end || queryEnd < root->start)
return 0;
// 查询区间完全覆盖当前节点区间
if (queryStart <= root->start && queryEnd >= root->end)
return root->sum;
// 查询区间与当前节点区间部分重叠,需要分别查询左右子树
int leftSum = querySum(root->left, queryStart, queryEnd);
int rightSum = querySum(root->right, queryStart, queryEnd);
return leftSum + rightSum;
}
单点更新
// 更新索引为index的元素值为val
void update(SegmentTreeNode* root, int index, int val) {
// 找到要更新的叶子节点
if (root->start == root->end && root->start == index) {
root->sum = val; // 更新叶子节点的值
return;
}
int mid = root->start + (root->end - root->start) / 2;
// 判断index在左子树还是右子树,并递归更新
if (index <= mid) {
update(root->left, index, val);
} else {
update(root->right, index, val);
}
// 更新当前节点的区间和
root->sum = root->left->sum + root->right->sum;
}
数组实现的线段树
class SegmentTree {
private:
vector<int> tree; // 存储线段树的数组
vector<int> lazy; // 懒标记数组,用于区间更新
int n; // 原始数组的长度
// 构建线段树
void build(vector<int>& nums, int node, int start, int end) {
if (start == end) {
tree[node] = nums[start];
return;
}
int mid = start + (end - start) / 2;
int leftChild = 2 * node + 1;
int rightChild = 2 * node + 2;
build(nums, leftChild, start, mid);
build(nums, rightChild, mid + 1, end);
tree[node] = tree[leftChild] + tree[rightChild];
}
// 更新单个节点值
void updateSingle(int node, int start, int end, int idx, int val) {
if (start == end) {
tree[node] = val;
return;
}
int mid = start + (end - start) / 2;
int leftChild = 2 * node + 1;
int rightChild = 2 * node + 2;
if (idx <= mid) {
updateSingle(leftChild, start, mid, idx, val);
} else {
updateSingle(rightChild, mid + 1, end, idx, val);
}
tree[node] = tree[leftChild] + tree[rightChild];
}
// 查询区间和
int queryRange(int node, int start, int end, int left, int right) {
if (right < start || left > end) {
return 0; // 查询区间在当前节点区间外
}
if (left <= start && end <= right) {
return tree[node]; // 当前节点区间完全在查询区间内
}
// 处理懒标记(如果有的话)
int mid = start + (end - start) / 2;
int leftChild = 2 * node + 1;
int rightChild = 2 * node + 2;
int sumLeft = queryRange(leftChild, start, mid, left, right);
int sumRight = queryRange(rightChild, mid + 1, end, left, right);
return sumLeft + sumRight;
}
// 区间更新(带懒标记)
void updateRange(int node, int start, int end, int left, int right, int val) {
// 先处理之前留下的懒标记
if (lazy[node] != 0) {
tree[node] += (end - start + 1) * lazy[node]; // 更新当前节点
if (start != end) { // 非叶子节点,传递懒标记给子节点
lazy[2*node+1] += lazy[node];
lazy[2*node+2] += lazy[node];
}
lazy[node] = 0; // 重置当前懒标记
}
// 没有重叠
if (start > end || start > right || end < left) return;
// 当前区间完全包含在目标区间内
if (left <= start && end <= right) {
tree[node] += (end - start + 1) * val;
if (start != end) { // 非叶子节点,添加懒标记
lazy[2*node+1] += val;
lazy[2*node+2] += val;
}
return;
}
// 部分重叠,继续递归处理
int mid = start + (end - start) / 2;
updateRange(2*node+1, start, mid, left, right, val);
updateRange(2*node+2, mid+1, end, left, right, val);
// 更新当前节点
tree[node] = tree[2*node+1] + tree[2*node+2];
}
public:
SegmentTree(vector<int>& nums) {
n = nums.size();
// 分配4*n的空间,确保足够
tree.resize(4 * n, 0);
lazy.resize(4 * n, 0);
if (n > 0) {
build(nums, 0, 0, n - 1);
}
}
void update(int idx, int val) {
updateSingle(0, 0, n - 1, idx, val);
}
void updateRange(int left, int right, int val) {
updateRange(0, 0, n - 1, left, right, val);
}
int query(int left, int right) {
return queryRange(0, 0, n - 1, left, right);
}
};
懒标记优化
懒标记是线段树的一种优化技术,用于处理区间更新操作。其核心思想是延迟更新:当需要更新某个区间时,先不更新该区间所有节点,而是在父节点标记一个懒标记,等到真正需要查询到该节点时再执行更新。
// 带懒标记的区间更新
void updateRangeWithLazy(SegmentTreeNode* root, int updateStart, int updateEnd, int val) {
// 首先处理当前节点的懒标记
if (root->lazy != 0) {
root->sum += (root->end - root->start + 1) * root->lazy; // 更新当前节点
if (root->start != root->end) { // 非叶子节点,向下传递懒标记
if (root->left) root->left->lazy += root->lazy;
if (root->right) root->right->lazy += root->lazy;
}
root->lazy = 0; // 重置当前懒标记
}
// 无交集
if (updateStart > root->end || updateEnd < root->start) return;
// 当前区间完全包含在更新区间中
if (updateStart <= root->start && updateEnd >= root->end) {
root->sum += (root->end - root->start + 1) * val; // 更新当前节点
if (root->start != root->end) { // 非叶子节点,设置子节点的懒标记
if (root->left) root->left->lazy += val;
if (root->right) root->right->lazy += val;
}
return;
}
// 部分重叠,继续递归更新
if (root->left) updateRangeWithLazy(root->left, updateStart, updateEnd, val);
if (root->right) updateRangeWithLazy(root->right, updateStart, updateEnd, val);
// 更新当前节点
root->sum = (root->left ? root->left->sum : 0) + (root->right ? root->right->sum : 0);
}
// 带懒标记的区间查询
int queryWithLazy(SegmentTreeNode* root, int queryStart, int queryEnd) {
// 无交集
if (!root || queryStart > root->end || queryEnd < root->start) return 0;
// 首先处理当前节点的懒标记
if (root->lazy != 0) {
root->sum += (root->end - root->start + 1) * root->lazy; // 更新当前节点
if (root->start != root->end) { // 非叶子节点,向下传递懒标记
if (root->left) root->left->lazy += root->lazy;
if (root->right) root->right->lazy += root->lazy;
}
root->lazy = 0; // 重置当前懒标记
}
// 当前区间完全包含在查询区间中
if (queryStart <= root->start && queryEnd >= root->end) return root->sum;
// 部分重叠,递归查询左右子树
int leftSum = root->left ? queryWithLazy(root->left, queryStart, queryEnd) : 0;
int rightSum = root->right ? queryWithLazy(root->right, queryStart, queryEnd) : 0;
return leftSum + rightSum;
}
复杂度分析
- 构建线段树的时间复杂度:O(n),其中n是原始数组的长度
- 查询操作的时间复杂度:O(log n)
- 更新操作的时间复杂度:O(log n)
- 区间更新操作的时间复杂度:O(log n)(使用懒标记)
- 空间复杂度:O(n)(数组实现通常是4*n)
典型应用场景
线段树广泛应用于需要频繁执行区间操作的场景:
- 【区间求和】:在指定区间内元素的和
- 【区间最值】:查找区间内的最大值或最小值
- 【区间更新】:将区间内所有元素增加、减少或设置为某个值
- 【区间统计】:统计区间内满足某条件的元素个数
典型例题分析
例题1: 区间和查询
问题描述:给定一个数组,支持两种操作:
- 更新某个位置的值
- 查询区间[i, j]的和
class NumArray {
private:
vector<int> tree;
int n;
void build(vector<int>& nums, int node, int start, int end) {
if (start == end) {
tree[node] = nums[start];
return;
}
int mid = start + (end - start) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
build(nums, leftNode, start, mid);
build(nums, rightNode, mid + 1, end);
tree[node] = tree[leftNode] + tree[rightNode];
}
void updateTree(int node, int start, int end, int idx, int val) {
if (start == end) {
tree[node] = val;
return;
}
int mid = start + (end - start) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
if (idx <= mid) {
updateTree(leftNode, start, mid, idx, val);
} else {
updateTree(rightNode, mid + 1, end, idx, val);
}
tree[node] = tree[leftNode] + tree[rightNode];
}
int queryTree(int node, int start, int end, int left, int right) {
if (left > end || right < start) {
return 0;
}
if (left <= start && right >= end) {
return tree[node];
}
int mid = start + (end - start) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
int leftSum = queryTree(leftNode, start, mid, left, right);
int rightSum = queryTree(rightNode, mid + 1, end, left, right);
return leftSum + rightSum;
}
public:
NumArray(vector<int>& nums) {
n = nums.size();
tree.resize(4 * n);
build(nums, 0, 0, n - 1);
}
void update(int index, int val) {
updateTree(0, 0, n - 1, index, val);
}
int sumRange(int left, int right) {
return queryTree(0, 0, n - 1, left, right);
}
};
分析:
- 时间复杂度:构建O(n),查询和更新均为O(log n)
- 空间复杂度:O(n)
- 核心思想:使用线段树高效处理区间操作,适合频繁的查询和更新操作组合
例题2: 区间最小值查询
问题描述:给定一个数组,支持查询区间[i, j]内的最小值。
class RangeMinQuery {
private:
vector<int> tree;
int n;
void build(vector<int>& nums, int node, int start, int end) {
if (start == end) {
tree[node] = nums[start];
return;
}
int mid = start + (end - start) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
build(nums, leftNode, start, mid);
build(nums, rightNode, mid + 1, end);
tree[node] = min(tree[leftNode], tree[rightNode]); // 存储区间最小值
}
int queryMin(int node, int start, int end, int left, int right) {
if (left > end || right < start) {
return INT_MAX; // 查询区间在当前节点区间外,返回最大值作为哨兵
}
if (left <= start && right >= end) {
return tree[node]; // 当前节点区间完全在查询区间内
}
int mid = start + (end - start) / 2;
int leftNode = 2 * node + 1;
int rightNode = 2 * node + 2;
int leftMin = queryMin(leftNode, start, mid, left, right);
int rightMin = queryMin(rightNode, mid + 1, end, left, right);
return min(leftMin, rightMin);
}
public:
RangeMinQuery(vector<int>& nums) {
n = nums.size();
tree.resize(4 * n);
build(nums, 0, 0, n - 1);
}
int getMin(int left, int right) {
return queryMin(0, 0, n - 1, left, right);
}
};
分析:
- 时间复杂度:查询O(log n)
- 空间复杂度:O(n)
- 核心思想:存储区间最小值而非区间和,函数逻辑基本相同,只是聚合方式改变
易错点与调试技巧
【区间重叠判定】正确判断查询/更新区间与节点区间的重叠关系
// 错误写法:条件不完整 if (left > end || right < start) { /* 无交集 */ } else if (left == start && right == end) { /* 完全包含 */ } // 正确写法 if (left > end || right < start) { /* 无交集 */ } else if (left <= start && end <= right) { /* 完全包含 */ }
【懒标记传递】确保正确传递和清除懒标记
// 错误写法:忘记传递懒标记给子节点 if (lazy[node] != 0) { tree[node] += (end - start + 1) * lazy[node]; lazy[node] = 0; // 直接清除而未传递 } // 正确写法 if (lazy[node] != 0) { tree[node] += (end - start + 1) * lazy[node]; if (start != end) { // 非叶子节点 lazy[2*node+1] += lazy[node]; lazy[2*node+2] += lazy[node]; } lazy[node] = 0; // 清除当前节点的懒标记 }
【更新后的回溯】在递归更新操作后,别忘记更新父节点值
// 错误写法:忘记回溯更新父节点 void update(int node, int start, int end, int idx, int val) { if (start == end) { tree[node] = val; return; } int mid = start + (end - start) / 2; if (idx <= mid) update(2*node+1, start, mid, idx, val); else update(2*node+2, mid+1, end, idx, val); // 缺少这一行 // tree[node] = tree[2*node+1] + tree[2*node+2]; }
【边界索引】注意线段树的索引计算,特别是在数组实现中
// 常见错误:子节点索引计算错误 int leftChild = 2 * node; // 错误 int rightChild = 2 * node + 1; // 错误 // 正确写法(数组从0开始索引) int leftChild = 2 * node + 1; int rightChild = 2 * node + 2;
优化与扩展
动态线段树
当区间范围非常大,但实际需要处理的元素较少时,可以使用动态线段树,只为实际存在的元素分配节点。
// 动态线段树的节点
struct DynamicSegmentTreeNode {
int start, end;
int sum;
DynamicSegmentTreeNode *left, *right;
DynamicSegmentTreeNode(int start, int end) {
this->start = start;
this->end = end;
this->sum = 0;
this->left = this->right = nullptr;
}
};
// 动态查询区间和
int querySum(DynamicSegmentTreeNode* root, int queryStart, int queryEnd) {
if (!root || queryStart > root->end || queryEnd < root->start) return 0;
if (queryStart <= root->start && queryEnd >= root->end) return root->sum;
int mid = root->start + (root->end - root->start) / 2;
// 延迟创建子节点
int leftSum = 0, rightSum = 0;
if (queryStart <= mid && root->left) {
leftSum = querySum(root->left, queryStart, queryEnd);
}
if (queryEnd > mid && root->right) {
rightSum = querySum(root->right, queryStart, queryEnd);
}
return leftSum + rightSum;
}
// 动态更新点值
void update(DynamicSegmentTreeNode*& root, int index, int val) {
if (!root) return;
if (root->start == root->end && root->start == index) {
root->sum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
// 延迟创建子节点
if (index <= mid) {
if (!root->left) {
root->left = new DynamicSegmentTreeNode(root->start, mid);
}
update(root->left, index, val);
} else {
if (!root->right) {
root->right = new DynamicSegmentTreeNode(mid + 1, root->end);
}
update(root->right, index, val);
}
root->sum = (root->left ? root->left->sum : 0) + (root->right ? root->right->sum : 0);
}
持久化线段树
持久化线段树(也称为函数式线段树或可持久化线段树)允许保存线段树的每个历史版本,适用于需要查询历史状态的场景。
// 持久化线段树的节点
struct PersistentSegTreeNode {
int sum;
PersistentSegTreeNode *left, *right;
PersistentSegTreeNode() : sum(0), left(nullptr), right(nullptr) {}
PersistentSegTreeNode(PersistentSegTreeNode* l, PersistentSegTreeNode* r) :
left(l), right(r), sum(0) {
if (l) sum += l->sum;
if (r) sum += r->sum;
}
};
// 创建新版本的线段树
PersistentSegTreeNode* update(PersistentSegTreeNode* prev, int start, int end, int idx, int val) {
if (start == end) {
PersistentSegTreeNode* node = new PersistentSegTreeNode();
node->sum = val;
return node;
}
int mid = start + (end - start) / 2;
PersistentSegTreeNode *left_child = prev ? prev->left : nullptr;
PersistentSegTreeNode *right_child = prev ? prev->right : nullptr;
if (idx <= mid) {
left_child = update(left_child, start, mid, idx, val);
} else {
right_child = update(right_child, mid + 1, end, idx, val);
}
return new PersistentSegTreeNode(left_child, right_child);
}
// 查询特定版本的线段树
int query(PersistentSegTreeNode* root, int start, int end, int left, int right) {
if (!root || left > end || right < start) return 0;
if (left <= start && end <= right) return root->sum;
int mid = start + (end - start) / 2;
return query(root->left, start, mid, left, right) +
query(root->right, mid + 1, end, left, right);
}
练习题推荐
- LeetCode 307: 区域和检索 - 数组可修改(基础线段树)
- LeetCode 315: 计算右侧小于当前元素的个数(线段树应用)
- SPOJ HORRIBLE: Horrible Queries(区间更新和查询)
- Codeforces 339D: Xenia and Bit Operations(线段树变种)
- Codeforces 242E: XOR on Segment(线段树与位运算结合)
- POJ 2991: Crane(持久化线段树)
- UVA 11402: Ahoy, Pirates!(线段树进阶应用)
- SPOJ KQUERY: K-query(持久化线段树应用)
总结
线段树是一种强大的数据结构,能够高效处理各种区间操作问题。它的核心优势在于对于区间查询和更新操作都能保持O(log n)的时间复杂度。通过掌握线段树的基本原理和各种变体(如懒标记、动态线段树、持久化线段树),你可以解决许多复杂的区间操作问题。
虽然线段树实现相对复杂,但它提供的效率提升在许多竞赛问题中是不可替代的。记住,线段树的关键是正确理解区间的划分和节点的管理,以及在查询和更新操作中维护节点信息的一致性。