二维树状数组学习笔记

发布于 2019-02-07  341 次阅读


我们把树状数组由一维扩展到二维。二维树状数组的定义是:
$C[x][y] = \sum A[i][j]$,其中
$x - lowbit(x) + 1 \leq i \leq x$
$y - lowbit(y) + 1 \leq j \leq y$
所以我们就可以很方便的写出来单点修改和查询 $(1,1)$ 到 $(x,y)$ 的和的代码了:

inline void add(int x,int y,int delta){
    for(int i = x;i <= n;i += lowbit(i)){
        for(int j = y;j <= m;j += lowbit(j)){
            tree[i][j] += delta;
        }
    }
}

inline int query(int x,int y){
    int res = 0;
    for(int i = x;i >= 1;i -= lowbit(i)){
        for(int j = y;j >= 1;j -= lowbit(j)){
            res += tree[i][j];
        }
    }
    return res;
}

那我们来考虑如何维护这样一个操作:二维矩阵整体加,查询矩阵和。
二维线段树肯定是可以做的,但是写起来太麻烦,甚至可能会爆空间。
我们想一下在一维树状数组里,我们维护区间加是通过差分原数组来实现的。
注意到一维下差分的性质:定义 D 为 A 的差分数组,则 $A_x = \sum_{i=1}^x D_i$
所以说我们可以类比一维,定义二维下差分数组为:$D_{i,j} = A_{i,j}-A{i,j-1}-A_{i-1,j}+A_{i-1,j-1}$。
我们考虑考虑修改操作。通过差分数组的定义不难得出我们修改 $(a,b)$ 到 $(c,d)$ 矩阵时,我们应该在差分数组的 $(a,b)$ 和 $(c+1,d+1)$ 加上 $delta$,同时在 $(a,d+1)$ 和 $(c+1,b)$ 减去 $delta$。
那我们如何处理询问操作呢?
我们不妨先写出最暴力的处理 $(1,1)$ 到 $(x,y)$ 的和,然后逐步优化。
根据差分数组的定义可得:
$$\sum_{i=1}^x \sum_{j=1}^y \sum_{h=1}^i \sum_{k=1}^j d_{h,k}$$
考虑差分数组每一位对答案的贡献,我们可以把其优化为二层枚举:
$$\sum_{i=1}^x \sum_{j=1}^y d_{i,j} * (x-i+1) * (y-j+1)$$
略微拆一下式子,可以得到
$$\sum_{i=1}^x \sum_{j=1}^y d_{i,j} * (x * y+x+y+1) - d_{i,j} * i(y+1) -d_{i,j} * j(x+1) + d_{i,j} * ij$$
用四个树状数组分别维护一下 $d_{i,j},d_{i,j} * i,d_{i,j} * j,d_{i,j} * i * j$ 即可。
一个题目链接
不过貌似这个东西常数比较大,在 Luogu 要开 O2 和读入优化......

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#define fi first
#define lc (x<<1)
#define se second
#define U unsigned
#define rc (x<<1|1)
#define Re register
#define LL long long
#define MP std::make_pair
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(Re int i = a;i <= b;++i)
#define ROF(i,a,b) for(Re int i = a;i >= b;--i)
#define SFOR(i,a,b,c) for(Re int i = a;i <= b;i+=c)
#define SROF(i,a,b,c) for(Re int i = a;i >= b;i-=c)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 2048+5;
int n,m;

inline void read(int &x){
    char ch = getchar();int flag = 0;x = 0;
    while(!isdigit(ch)){
        if(ch == '-') flag = 1;
        ch = getchar();
    }
    while(isdigit(ch)){
        x = (x<<1) + (x<<3) + (ch^'0');
        ch = getchar();
    }
    if(flag) x = -x;
}

struct BIT{
    #define lowbit(x) (x&-x)
    int tree[MAXN][MAXN];

    inline void add(int x,int y,int delta){
        for(int i = x;i <= n;i += lowbit(i)){
            for(int j = y;j <= m;j += lowbit(j)){
                tree[i][j] += delta;
            }
        }
    }

    inline int query(int x,int y){
        int res = 0;
        for(int i = x;i >= 1;i -= lowbit(i)){
            for(int j = y;j >= 1;j -= lowbit(j)){
                res += tree[i][j];
            }
        }
        return res;
    }
    #undef lowbit
}A,Ai,Aj,Aij;

inline int calc(int x,int y){
    return A.query(x,y)*(x*y+x+y+1) - Ai.query(x,y)*(y+1) - Aj.query(x,y)*(x+1) + Aij.query(x,y);
}

inline void add(int x,int y,int num){
    A.add(x,y,num);Ai.add(x,y,num*x);
    Aj.add(x,y,num*y);Aij.add(x,y,num*x*y);
}
int opt[3];
int main(){
    read(n);read(m);
    while(~scanf("%s",opt)){
        int a,b,c,d;read(a);read(b);read(c);read(d);
        if(opt[0] == 'L'){
            int del;read(del);
            add(a,b,del);add(a,d+1,-del);add(c+1,b,-del);add(c+1,d+1,del);
        }
        else printf("%d\n",calc(c,d)-calc(a-1,d)-calc(c,b-1)+calc(a-1,b-1));
    }
    return 0;
}

一个OIer。