跳转至

树状数组

概述

树状数组(Binary Indexed Tree,BIT),又称二叉索引树或 Fenwick Tree,是一种支持单点修改和区间查询的高效数据结构。它通过巧妙的二进制分解,将线性数组组织成树形结构,实现了 O(log n) 的更新和查询效率。

核心优势:代码极其简洁(核心仅需10行左右),常数因子小,实际运行效率往往优于线段树。适用于单点修改+区间查询、区间修改+单点查询等场景。

Lowbit 函数详解

定义

Lowbit 函数返回一个数二进制表示中最低位 1 代表的值:

C
1
2
3
int lowbit(int x) {
    return x & (-x);
}

原理解析

利用补码特性:-x = ~x + 1

计算示例:x = 12

x  = 12 = 00001100 (二进制)
~x =    = 11110011
-x =    = 11110100 (补码)

x & (-x) = 00001100 & 11110100 = 00000100 = 4

Lowbit 值表

x 二进制 lowbit(x) 含义
1 0001 1 最低位是第0位
2 0010 2 最低位是第1位
3 0011 1 最低位是第0位
4 0100 4 最低位是第2位
5 0101 1 最低位是第0位
6 0110 2 最低位是第1位
7 0111 1 最低位是第0位
8 1000 8 最低位是第3位

可视化

graph TB
    subgraph Lowbit原理
        A["x = 12 (00001100)"]
        B["-x = -12 (11110100)"]
        C["x & -x = 4 (00000100)"]
        A --> D["& 运算"]
        B --> D
        D --> C
    end
    
    style A fill:#E3F2FD
    style B fill:#E3F2FD
    style C fill:#E8F5E9

树状数组结构

结构定义

树状数组将原数组 A 映射到辅助数组 C,其中:

  • C[i] 存储 A 数组中从 i - lowbit(i) + 1i 的元素和
  • C[i] = A[i - lowbit(i) + 1] + ... + A[i]

结构可视化

假设原数组 A = [1, 3, 5, 7, 9, 11, 13, 15](索引从1开始):

Text Only
原数组 A:
索引:   1   2   3   4   5   6   7   8
值:    [1] [3] [5] [7] [9][11][13][15]

树状数组 C:
C[1] = A[1]                    = 1
C[2] = A[1] + A[2]             = 4
C[3] = A[3]                    = 5
C[4] = A[1] + A[2] + A[3] + A[4] = 16
C[5] = A[5]                    = 9
C[6] = A[5] + A[6]             = 20
C[7] = A[7]                    = 13
C[8] = A[1] + ... + A[8]       = 64

树形结构图

graph TB
    subgraph 树状数组结构
        C8(("C[8]=64<br/>A[1..8]"))
        C4(("C[4]=16<br/>A[1..4]"))
        C2(("C[2]=4<br/>A[1..2]"))
        C1(("C[1]=1<br/>A[1]"))
        C3(("C[3]=5<br/>A[3]"))
        C6(("C[6]=20<br/>A[5..6]"))
        C5(("C[5]=9<br/>A[5]"))
        C7(("C[7]=13<br/>A[7]"))
        
        C8 --> C4
        C8 --> C6
        C4 --> C2
        C4 --> C3
        C2 --> C1
        C6 --> C5
        C6 --> C7
    end
    
    style C8 fill:#4CAF50,color:#fff
    style C4 fill:#2196F3,color:#fff
    style C6 fill:#2196F3,color:#fff

覆盖关系图

Text Only
索引 i  | lowbit(i) | C[i] 覆盖范围 | 管理的A元素
--------|-----------|--------------|------------------
   1    |     1     |    [1, 1]    | A[1]
   2    |     2     |    [1, 2]    | A[1], A[2]
   3    |     1     |    [3, 3]    | A[3]
   4    |     4     |    [1, 4]    | A[1], A[2], A[3], A[4]
   5    |     1     |    [5, 5]    | A[5]
   6    |     2     |    [5, 6]    | A[5], A[6]
   7    |     1     |    [7, 7]    | A[7]
   8    |     8     |    [1, 8]    | A[1]..A[8]

ASCII 内存布局

Text Only
        索引:  0   1   2   3   4   5   6   7   8
              ┌───┬───┬───┬───┬───┬───┬───┬───┬───┐
原数组 A:     │ - │ 1 │ 3 │ 5 │ 7 │ 9 │11 │13 │15 │
              └───┴───┴───┴───┴───┴───┴───┴───┴───┘
              ┌───┬───┬───┬───┬───┬───┬───┬───┬───┐
树状数组 C:   │ - │ 1 │ 4 │ 5 │16 │ 9 │20 │13 │64 │
              └───┴───┴───┴───┴───┴───┴───┴───┴───┘
                ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑
               不用 │   │   │   │   │   │   │   │
                   └───┼───┘   │   │   │   │   │
                      C[2]    └───┼───┘   │   │
                     覆盖2个      C[4]    └───┼───┘
                     元素       覆盖4个      C[8]
                                元素        覆盖8个
                                           元素

核心操作详解

单点更新

更新 A[i] 时,需要更新所有包含 A[i] 的 C[j],即 j = i, i + lowbit(i), ...

C
1
2
3
4
5
6
void update(int tree[], int n, int index, int value) {
    while (index <= n) {
        tree[index] += value;
        index += lowbit(index);  // 跳到下一个需要更新的节点
    }
}

更新过程示例:更新 A[3] 加 2

flowchart TD
    A["初始: index = 3"]
    B["C[3] += 2"]
    C["index = 3 + lowbit(3) = 4"]
    D["C[4] += 2"]
    E["index = 4 + lowbit(4) = 8"]
    F["C[8] += 2"]
    G["index = 8 + lowbit(8) = 16 > n"]
    H["结束"]
    
    A --> B --> C --> D --> E --> F --> G --> H
    
    style A fill:#E3F2FD
    style B fill:#4CAF50,color:#fff
    style D fill:#4CAF50,color:#fff
    style F fill:#4CAF50,color:#fff
    style H fill:#E8F5E9
Text Only
1
2
3
4
5
6
7
8
9
更新 A[3] += 2 的过程:

步骤1: index = 3, C[3] += 2 (5→7)
步骤2: index = 3 + 1 = 4, C[4] += 2 (16→18)
步骤3: index = 4 + 4 = 8, C[8] += 2 (64→66)
步骤4: index = 8 + 8 = 16 > 8, 结束

更新后:
C[3] = 7, C[4] = 18, C[8] = 66

前缀查询

查询 A[1] + A[2] + ... + A[i],通过累加 C 数组的特定元素实现:

C
1
2
3
4
5
6
7
8
int query(int tree[], int index) {
    int sum = 0;
    while (index > 0) {
        sum += tree[index];
        index -= lowbit(index);  // 跳到前一个区间
    }
    return sum;
}

查询过程示例:查询 sum[1..7]

flowchart TD
    A["初始: sum = 0, index = 7"]
    B["sum += C[7] = 13"]
    C["index = 7 - 1 = 6"]
    D["sum += C[6] = 13 + 20 = 33"]
    E["index = 6 - 2 = 4"]
    F["sum += C[4] = 33 + 16 = 49"]
    G["index = 4 - 4 = 0"]
    H["结束, 返回 49"]
    
    A --> B --> C --> D --> E --> F --> G --> H
    
    style A fill:#E3F2FD
    style B fill:#2196F3,color:#fff
    style D fill:#2196F3,color:#fff
    style F fill:#2196F3,color:#fff
    style H fill:#E8F5E9
Text Only
1
2
3
4
5
6
7
8
9
查询 sum[1..7] 的过程:

步骤1: index = 7, sum += C[7] = 13
步骤2: index = 7 - 1 = 6, sum += C[6] = 13 + 20 = 33
步骤3: index = 6 - 2 = 4, sum += C[4] = 33 + 16 = 49
步骤4: index = 4 - 4 = 0, 结束

结果: sum[1..7] = C[7] + C[6] + C[4] = 13 + 20 + 16 = 49
验证: A[1]+...+A[7] = 1+3+5+7+9+11+13 = 49 ✓

区间查询

利用前缀和思想,区间 [l, r] 的和等于 sum[1..r] - sum[1..l-1]:

C
1
2
3
int rangeQuery(int tree[], int l, int r) {
    return query(tree, r) - query(tree, l - 1);
}

查询过程示例:查询 sum[3..6]

Text Only
1
2
3
4
5
6
7
sum[3..6] = sum[1..6] - sum[1..2]

sum[1..6] = C[6] + C[4] = 20 + 16 = 36
sum[1..2] = C[2] = 4

结果: sum[3..6] = 36 - 4 = 32
验证: A[3] + A[4] + A[5] + A[6] = 5 + 7 + 9 + 11 = 32 ✓

完整实现

数据结构定义

C
1
2
3
4
5
6
7
8
#include <stdlib.h>
#include <string.h>

typedef struct {
    int *tree;    // 树状数组
    int *arr;     // 原数组(用于差值计算)
    int n;        // 数组大小
} BIT;

创建与初始化

C
BIT* createBIT(int n) {
    BIT *bit = (BIT*)malloc(sizeof(BIT));
    bit->n = n;
    bit->tree = (int*)calloc(n + 1, sizeof(int));
    bit->arr = (int*)calloc(n + 1, sizeof(int));
    return bit;
}

void initBIT(BIT *bit, int arr[], int n) {
    for (int i = 1; i <= n; i++) {
        bit->arr[i] = arr[i - 1];
        update(bit->tree, n, i, arr[i - 1]);
    }
}

完整操作接口

C
void updateBIT(BIT *bit, int index, int value) {
    int delta = value - bit->arr[index];
    bit->arr[index] = value;
    update(bit->tree, bit->n, index, delta);
}

int queryBIT(BIT *bit, int index) {
    return query(bit->tree, index);
}

int rangeQueryBIT(BIT *bit, int l, int r) {
    return rangeQuery(bit->tree, l, r);
}

void destroyBIT(BIT *bit) {
    free(bit->tree);
    free(bit->arr);
    free(bit);
}

O(n) 建树优化

C
1
2
3
4
5
6
7
8
9
void buildBIT(int tree[], int arr[], int n) {
    for (int i = 1; i <= n; i++) {
        tree[i] += arr[i];
        int j = i + lowbit(i);
        if (j <= n) {
            tree[j] += tree[i];
        }
    }
}
建树优化:传统建树需要 O(n log n),利用父子关系可优化到 O(n)。对于每个节点,将其值累加到父节点。

区间更新 + 单点查询

利用差分思想,将"区间更新+单点查询"转化为"单点更新+前缀查询":

C
typedef struct {
    int *diff;    // 差分数组的树状数组
    int n;
} BITDiff;

BITDiff* createBITDiff(int n) {
    BITDiff *bit = (BITDiff*)malloc(sizeof(BITDiff));
    bit->n = n;
    bit->diff = (int*)calloc(n + 1, sizeof(int));
    return bit;
}

// 区间 [l, r] 每个元素加 value
void rangeUpdateDiff(BITDiff *bit, int l, int r, int value) {
    // 差分数组: D[l] += value, D[r+1] -= value
    int index = l;
    while (index <= bit->n) {
        bit->diff[index] += value;
        index += lowbit(index);
    }
    
    index = r + 1;
    while (index <= bit->n) {
        bit->diff[index] -= value;
        index += lowbit(index);
    }
}

// 查询单点值(差分前缀和)
int pointQueryDiff(BITDiff *bit, int index) {
    int sum = 0;
    while (index > 0) {
        sum += bit->diff[index];
        index -= lowbit(index);
    }
    return sum;
}

差分原理可视化

Text Only
原数组 A:           [1,  3,  5,  7,  9, 11, 13, 15]
差分数组 D:         [1,  2,  2,  2,  2,  2,  2,  2]
                       ↑                       ↑
                    D[1]=A[1]              D[i]=A[i]-A[i-1]

区间 [3, 6] 加 10:
修改差分数组: D[3] += 10, D[7] -= 10

差分数组变为:       [1,  2, 12,  2,  2,  2, -8,  2]

还原数组:
A[1] = D[1] = 1
A[2] = D[1] + D[2] = 3
A[3] = D[1] + D[2] + D[3] = 15  (原值5+10=15) ✓
A[4] = ... = 17  (原值7+10=17) ✓
A[5] = ... = 19  (原值9+10=19) ✓
A[6] = ... = 21  (原值11+10=21) ✓
A[7] = ... = 13  (不变) ✓

C++ 模板实现

C++
template<typename T>
class BIT {
private:
    std::vector<T> tree;
    int n;
    
public:
    BIT(int size) : n(size), tree(size + 1, 0) {}
    
    void update(int index, T value) {
        while (index <= n) {
            tree[index] += value;
            index += index & (-index);
        }
    }
    
    T query(int index) {
        T sum = 0;
        while (index > 0) {
            sum += tree[index];
            index -= index & (-index);
        }
        return sum;
    }
    
    T rangeQuery(int l, int r) {
        return query(r) - query(l - 1);
    }
    
    // 二分查找第k小(需要值域树状数组)
    int findKth(int k) {
        int pos = 0;
        int bitMask = 1;
        while (bitMask <= n) bitMask <<= 1;
        bitMask >>= 1;
        
        for (; bitMask; bitMask >>= 1) {
            int next = pos + bitMask;
            if (next <= n && tree[next] < k) {
                k -= tree[next];
                pos = next;
            }
        }
        return pos + 1;
    }
};

应用实例

1. 逆序对计数

C
long long countInversions(int nums[], int n) {
    // 离散化
    int *sorted = (int*)malloc(sizeof(int) * n);
    memcpy(sorted, nums, sizeof(int) * n);
    // 排序并离散化...
    
    int maxVal = n;  // 离散化后值域为 [1, n]
    int *tree = (int*)calloc(maxVal + 1, sizeof(int));
    long long count = 0;
    
    // 从右向左扫描
    for (int i = n - 1; i >= 0; i--) {
        // 查询比当前数小的数的个数
        int smaller = query(tree, nums[i] - 1);
        count += smaller;
        // 将当前数加入树状数组
        update(tree, maxVal, nums[i], 1);
    }
    
    free(tree);
    free(sorted);
    return count;
}

逆序对原理

flowchart LR
    A["数组: [3, 1, 4, 2]"]
    B["从右向左扫描"]
    C["遇到 2: 无更小数"]
    D["遇到 4: 有1个更小数 2"]
    E["遇到 1: 无更小数"]
    F["遇到 3: 有2个更小数 1,2"]
    G["逆序对数 = 0+1+0+2 = 3"]
    
    A --> B --> C --> D --> E --> F --> G
    
    style G fill:#E8F5E9

2. 动态求第 K 小

C
int findKth(int tree[], int n, int k) {
    int pos = 0;
    int bitMask = 1;
    
    // 找到最大的2的幂次
    while (bitMask <= n) bitMask <<= 1;
    bitMask >>= 1;
    
    // 二分查找
    for (; bitMask; bitMask >>= 1) {
        int next = pos + bitMask;
        if (next <= n && tree[next] < k) {
            k -= tree[next];
            pos = next;
        }
    }
    
    return pos + 1;
}

查找过程示例

Text Only
树状数组表示频率统计: tree[i] 表示值 <= i 的元素个数
查找第 k=5 小的元素

步骤1: bitMask = 8, pos = 0, next = 8
       tree[8] = 7 >= 5, 不跳

步骤2: bitMask = 4, pos = 0, next = 4
       tree[4] = 4 < 5, k = 5-4 = 1, pos = 4

步骤3: bitMask = 2, pos = 4, next = 6
       tree[6] = 5 >= 1, 不跳

步骤4: bitMask = 1, pos = 4, next = 5
       tree[5] = 4 < 1? 否, 不跳

结果: pos + 1 = 5, 第5小的元素是5

3. 区间最值(变种)

C++
class BITMax {
private:
    std::vector<int> tree;
    int n;
    
public:
    BITMax(int size) : n(size), tree(size + 1, INT_MIN) {}
    
    void update(int index, int value) {
        while (index <= n) {
            tree[index] = std::max(tree[index], value);
            index += index & (-index);
        }
    }
    
    int query(int index) {
        int result = INT_MIN;
        while (index > 0) {
            result = std::max(result, tree[index]);
            index -= index & (-index);
        }
        return result;
    }
};

树状数组 vs 线段树

graph TB
    subgraph 树状数组
        A1["代码简洁 (~10行)"]
        A2["空间 O(n)"]
        A3["常数因子小"]
        A4["仅支持前缀操作"]
        A5["难以扩展"]
    end
    
    subgraph 线段树
        B1["代码复杂 (~50行)"]
        B2["空间 O(4n)"]
        B3["常数因子较大"]
        B4["支持任意区间"]
        B5["易于扩展"]
    end
    
    style A1 fill:#4CAF50,color:#fff
    style A3 fill:#4CAF50,color:#fff
    style B4 fill:#4CAF50,color:#fff
    style B5 fill:#4CAF50,color:#fff
特性 树状数组 线段树
代码复杂度 简单(~10行) 复杂(~50行)
空间复杂度 O(n) O(4n)
单点修改 O(log n) O(log n)
区间查询 O(log n) O(log n)
区间修改 需要差分变体 原生支持
可扩展性 较差 强(支持各种操作)
常数因子
适用场景 单点更新+区间查询 复杂区间操作
选择建议:如果只需要单点修改+区间查询,优先选择树状数组(代码简洁、效率高);如果需要区间修改、区间最值等复杂操作,选择线段树。

时间复杂度分析

操作 时间复杂度 说明
建树 O(n log n) 或 O(n) O(n) 建树需要优化
单点更新 O(log n) 最多更新 log n 个节点
前缀查询 O(log n) 最多查询 log n 个节点
区间查询 O(log n) 两次前缀查询
区间更新(差分) O(log n) 两次单点更新

复杂度推导

Text Only
1
2
3
4
5
6
7
8
9
更新操作:
- 每次迭代 index += lowbit(index)
- index 的二进制末尾 0 的数量至少增加 1
- 最多迭代 log₂n 次

查询操作:
- 每次迭代 index -= lowbit(index)
- index 的二进制 1 的数量至少减少 1
- 最多迭代 log₂n 次

空间复杂度

实现 空间复杂度 说明
基本实现 O(n) 仅需 tree 数组
完整实现 O(n) tree + arr 数组
差分实现 O(n) 仅需 diff 数组

二维树状数组

支持矩阵区域更新和查询:

C++
class BIT2D {
private:
    std::vector<std::vector<int>> tree;
    int n, m;
    
public:
    BIT2D(int rows, int cols) : n(rows), m(cols), 
                                tree(rows + 1, std::vector<int>(cols + 1, 0)) {}
    
    void update(int x, int y, int value) {
        for (int i = x; i <= n; i += i & (-i)) {
            for (int j = y; j <= m; j += j & (-j)) {
                tree[i][j] += value;
            }
        }
    }
    
    int query(int x, int y) {
        int sum = 0;
        for (int i = x; i > 0; i -= i & (-i)) {
            for (int j = y; j > 0; j -= j & (-j)) {
                sum += tree[i][j];
            }
        }
        return sum;
    }
    
    int rangeQuery(int x1, int y1, int x2, int y2) {
        return query(x2, y2) - query(x1-1, y2) 
               - query(x2, y1-1) + query(x1-1, y1-1);
    }
};

典型应用场景总结

应用场景 描述 时间复杂度
区间求和 动态维护前缀和 O(log n)
逆序对 计数排序思想 O(n log n)
动态排名 实时统计第K大/小 O(log n)
频率统计 元素计数 O(log n)
区间更新 差分思想 O(log n)
二维区域和 矩阵区域查询 O(log²n)

实际问题示例

LeetCode 307. 区域和检索 - 数组可修改

C++
class NumArray {
private:
    BIT<int> bit;
    vector<int> nums;
    
public:
    NumArray(vector<int>& nums) : bit(nums.size()), nums(nums) {
        for (int i = 0; i < nums.size(); i++) {
            bit.update(i + 1, nums[i]);
        }
    }
    
    void update(int index, int val) {
        int delta = val - nums[index];
        nums[index] = val;
        bit.update(index + 1, delta);
    }
    
    int sumRange(int left, int right) {
        return bit.rangeQuery(left + 1, right + 1);
    }
};

LeetCode 315. 计算右侧小于当前元素的个数

C++
vector<int> countSmaller(vector<int>& nums) {
    int n = nums.size();
    vector<int> result(n);
    
    // 离散化
    vector<int> sorted = nums;
    sort(sorted.begin(), sorted.end());
    sorted.erase(unique(sorted.begin(), sorted.end()), sorted.end());
    
    int maxVal = sorted.size();
    BIT<int> bit(maxVal);
    
    auto rank = [&](int x) {
        return lower_bound(sorted.begin(), sorted.end(), x) 
               - sorted.begin() + 1;
    };
    
    for (int i = n - 1; i >= 0; i--) {
        int r = rank(nums[i]);
        result[i] = bit.query(r - 1);
        bit.update(r, 1);
    }
    
    return result;
}

参考资料

  • Fenwick, P. M. (1994). A New Data Structure for Cumulative Frequency Tables
  • 《算法竞赛入门经典》树状数组章节
  • 《挑战程序设计竞赛》数据结构章节