跳转至

主席树(可持久化线段树)

概述

主席树(Persistent Segment Tree)是一种可持久化数据结构,能够保存所有历史版本,支持在任意历史版本上进行查询和修改。通过路径复制技术,每次修改只新增 O(log n) 个节点,空间复杂度为 O(n log n)。

命名由来:由黄嘉泰(网名"主席")引入国内竞赛圈,故称"主席树"。本质是可持久化权值线段树。

主席树特点

特性 说明
可持久化 保留所有历史版本,可回溯
空间高效 路径复制,每次修改 O(log n) 空间
版本查询 支持任意版本上的查询
函数式 操作不修改原结构,返回新版本
静态查询 静态区间第 K 小的经典解法

可持久化原理

路径复制

修改时,只复制从根到修改点的路径,其他节点共享:

原线段树
A
B
D
E
C
修改节点 D
A'
B'
D'
E
C
版本 0: A
版本 1: A' (A'、B'、D' 新建,C、E 共享)

路径复制可视化

graph TB
    subgraph 版本0
        A0(("A"))
        B0(("B"))
        C0(("C"))
        D0(("D"))
        E0(("E"))
        
        A0 --> B0 --> D0
        B0 --> E0
        A0 --> C0
    end
    
    subgraph 版本1
        A1(("A'"))
        B1(("B'"))
        C1(("C"))
        D1(("D'"))
        E1(("E"))
        
        A1 --> B1 --> D1
        B1 --> E1
        A1 --> C1
    end
    
    style A0 fill:#4CAF50,color:#fff
    style A1 fill:#2196F3,color:#fff
    style D1 fill:#FF9800,color:#fff
空间分析:每次修改创建 O(log n) 个新节点。n 次修改后,总节点数为 O(n log n)。

数据结构

C
#define MAXN 200005
#define LOG 20

typedef struct PersistNode {
    int count;        // 区间内元素个数
    int left;         // 左子节点索引
    int right;        // 右子节点索引
} PersistNode;

typedef struct {
    PersistNode nodes[MAXN * LOG];  // 节点池
    int roots[MAXN];                 // 各版本的根节点
    int nodeCount;                   // 节点计数
    int versionCount;                // 版本计数
} PersistentSegmentTree;

创建主席树

C
PersistentSegmentTree* createPersistTree(int n) {
    PersistentSegmentTree *tree = (PersistentSegmentTree*)calloc(1, sizeof(PersistentSegmentTree));
    tree->nodeCount = 0;
    tree->versionCount = 0;
    tree->roots[0] = 0;
    return tree;
}

int newNode(PersistentSegmentTree *tree) {
    int id = tree->nodeCount++;
    tree->nodes[id].count = 0;
    tree->nodes[id].left = 0;
    tree->nodes[id].right = 0;
    return id;
}

构建主席树

C
int build(PersistentSegmentTree *tree, int left, int right) {
    int node = newNode(tree);
    
    if (left == right) {
        return node;
    }
    
    int mid = (left + right) / 2;
    tree->nodes[node].left = build(tree, left, mid);
    tree->nodes[node].right = build(tree, mid + 1, right);
    
    return node;
}

void initPersistTree(PersistentSegmentTree *tree, int n) {
    tree->roots[0] = build(tree, 1, n);
    tree->versionCount = 1;
}

插入操作(创建新版本)

C
int insert(PersistentSegmentTree *tree, int prevRoot, 
           int left, int right, int pos, int value) {
    // 创建新节点,复制旧节点信息
    int node = newNode(tree);
    tree->nodes[node] = tree->nodes[prevRoot];
    tree->nodes[node].count += value;
    
    if (left == right) {
        return node;
    }
    
    int mid = (left + right) / 2;
    
    if (pos <= mid) {
        // 修改左子树,右子树共享
        tree->nodes[node].left = insert(tree, tree->nodes[prevRoot].left, 
                                        left, mid, pos, value);
    } else {
        // 修改右子树,左子树共享
        tree->nodes[node].right = insert(tree, tree->nodes[prevRoot].right, 
                                         mid + 1, right, pos, value);
    }
    
    return node;
}

void addVersion(PersistentSegmentTree *tree, int pos, int value, int n) {
    int prevRoot = tree->roots[tree->versionCount - 1];
    tree->roots[tree->versionCount] = insert(tree, prevRoot, 1, n, pos, value);
    tree->versionCount++;
}

插入过程示意

初始: 空树 → 插入位置 3,值 +1
执行步骤:
步骤1: 创建根节点 root[1],复制 root[0]
步骤2: 进入左子树 [1,4],创建新节点
步骤3: 进入右子树 [3,4],创建新节点
步骤4: 到达叶子 [3,3],count = 1
新创建的节点(带 * 标记):
        [1,4]*
       /    \
    [1,2]   [3,4]*
            /  \
          [3,3]* [4,4]
* 表示新建节点,其他节点与前一版本共享

查询操作

C
int query(PersistentSegmentTree *tree, int root, 
          int left, int right, int ql, int qr) {
    if (ql > right || qr < left) return 0;
    if (ql <= left && right <= qr) {
        return tree->nodes[root].count;
    }
    
    int mid = (left + right) / 2;
    return query(tree, tree->nodes[root].left, left, mid, ql, qr) +
           query(tree, tree->nodes[root].right, mid + 1, right, ql, qr);
}

int queryVersion(PersistentSegmentTree *tree, int version, 
                 int ql, int qr, int n) {
    return query(tree, tree->roots[version], 1, n, ql, qr);
}

静态区间第 K 小

原理

利用前缀和思想:区间 [l, r] 的信息 = 版本 r 的信息 - 版本 (l-1) 的信息

graph TB
    subgraph 区间第K小原理
        A["版本 r: root r"]
        B["版本 l-1: root l-1"]
        C["差值 = 区间 l,r 的统计"]
        D["二分查找第 K 小"]
        
        A --> C
        B --> C
        C --> D
    end
    
    style D fill:#E8F5E9

实现

C
int queryKth(PersistentSegmentTree *tree, int rootLeft, int rootRight,
             int left, int right, int k) {
    if (left == right) return left;
    
    int mid = (left + right) / 2;
    
    // 左子树元素个数 = 版本 r 的左子树 - 版本 l-1 的左子树
    int leftCount = tree->nodes[tree->nodes[rootRight].left].count - 
                    tree->nodes[tree->nodes[rootLeft].left].count;
    
    if (leftCount >= k) {
        // 第 k 小在左子树
        return queryKth(tree, tree->nodes[rootLeft].left, 
                       tree->nodes[rootRight].left, left, mid, k);
    } else {
        // 第 k 小在右子树
        return queryKth(tree, tree->nodes[rootLeft].right,
                       tree->nodes[rootRight].right, mid + 1, right, k - leftCount);
    }
}

int kthSmallest(PersistentSegmentTree *tree, int l, int r, int k, int n) {
    return queryKth(tree, tree->roots[l - 1], tree->roots[r], 1, n, k);
}

查询示例

数组: [3, 1, 4, 1, 5]
离散化后: [2, 1, 3, 1, 4] (值域 1-4)
询问: 区间 [2, 5] 第 3 小
版本构建:
版本 1 (处理 A[1]=3): root[1]
版本 2 (处理 A[2]=1): root[2]
版本 3 (处理 A[3]=4): root[3]
版本 4 (处理 A[4]=1): root[4]
版本 5 (处理 A[5]=5): root[5]
查询区间 [2, 5]:
rootLeft = root[1]
rootRight = root[5]
步骤1: 查找第 3 小
左子树 [1,2] 元素数 = (root[5] 左子树) - (root[1] 左子树) = 2 - 0 = 2
2 < 3,第 3 小在右子树 [3,4]
新 k = 3 - 2 = 1
步骤2: 查找第 1 小(右子树)
左子树 [3,3] 元素数 = 1
1 >= 1,第 1 小在 [3,3]
结果: 3(对应原数组值 4
注意事项:需要先对原数组离散化,将值域压缩到 [1, n]。每个版本对应原数组的前缀,root[i] 存储 A[1..i] 的权值分布。

C++ 实现

C++
#include <vector>

class PersistentSegmentTree {
private:
    struct Node {
        int count;
        int left, right;
        Node() : count(0), left(0), right(0) {}
    };
    
    std::vector<Node> nodes;
    std::vector<int> roots;
    int n;
    
    int newNode() {
        nodes.push_back(Node());
        return nodes.size() - 1;
    }
    
    int build(int left, int right) {
        int node = newNode();
        if (left == right) return node;
        
        int mid = (left + right) / 2;
        nodes[node].left = build(left, mid);
        nodes[node].right = build(mid + 1, right);
        return node;
    }
    
    int insert(int prev, int left, int right, int pos, int value) {
        int node = newNode();
        nodes[node] = nodes[prev];
        nodes[node].count += value;
        
        if (left == right) return node;
        
        int mid = (left + right) / 2;
        if (pos <= mid) {
            nodes[node].left = insert(nodes[prev].left, left, mid, pos, value);
        } else {
            nodes[node].right = insert(nodes[prev].right, mid + 1, right, pos, value);
        }
        
        return node;
    }
    
    int queryKth(int rootLeft, int rootRight, int left, int right, int k) {
        if (left == right) return left;
        
        int mid = (left + right) / 2;
        int count = nodes[nodes[rootRight].left].count - 
                   nodes[nodes[rootLeft].left].count;
        
        if (count >= k) {
            return queryKth(nodes[rootLeft].left, nodes[rootRight].left, 
                           left, mid, k);
        } else {
            return queryKth(nodes[rootLeft].right, nodes[rootRight].right,
                           mid + 1, right, k - count);
        }
    }
    
public:
    PersistentSegmentTree(int size) : n(size) {
        roots.push_back(build(1, n));
    }
    
    void insert(int pos, int value) {
        roots.push_back(insert(roots.back(), 1, n, pos, value));
    }
    
    int kthSmallest(int l, int r, int k) {
        return queryKth(roots[l - 1], roots[r], 1, n, k);
    }
    
    int versionCount() { return roots.size(); }
};

时间复杂度

操作 时间复杂度 空间复杂度 说明
构建 O(n log n) O(n log n) n 次插入
单次插入 O(log n) O(log n) 创建 log n 个节点
区间查询 O(log n) O(1)
区间第 K 小 O(log n) O(1) 两次查询的差

空间复杂度分析

初始构建: n 个节点(空树)
每次插入: log n 个新节点
n 次插入后: n + n log n = O(n log n) 个节点
每个节点存储: count + left + right = O(1)
总空间: O(n log n)

应用场景

应用领域 具体问题
区间第 K 小 静态/动态区间第 K 小值
区间数颜色 区间不同数字个数
历史版本查询 回溯到任意历史状态
树上路径查询 结合树链剖分
可持久化并查集 历史连通性查询

主席树 vs 其他方法

方法 区间第 K 小 空间 预处理
排序 O(n log n) 每次 O(n)
划分树 O(log n) O(n log n) O(n log n)
主席树 O(log n) O(n log n) O(n log n)
莫队+分块 O(√n) O(n) O(n)
选择建议:静态区间第 K 小优先选择主席树,实现简单、查询高效。需要在线查询时主席树是最佳选择。

扩展:动态主席树

支持修改操作,结合树状数组:

C++
// 树状数组套主席树
class DynamicPersistentSegmentTree {
    vector<PersistentSegmentTree> bits;
    
    void update(int i, int pos, int value) {
        for (; i <= n; i += i & (-i)) {
            bits[i].insert(pos, value);
        }
    }
    
    int queryKth(int l, int r, int k) {
        // 利用树状数组的前缀和性质
        // ...
    }
};

参考资料

  • 黄嘉泰(主席)《可持久化数据结构研究》
  • 《算法竞赛进阶指南》可持久化数据结构章节
  • 《数据结构》可持久化线段树