[bzoj 4540] [Hnoi2016]序列:离线,线段树,矩阵乘法

来源:互联网 发布:阿里旺旺国际版 mac 编辑:程序博客网 时间:2024/05/22 12:02

题意:给一个长为N的序列,每项绝对值不超过10^9。Q个询问,求某区间的所有子区间的最小值之和。(1 ≤N,Q ≤ 100000)

UPDATE 2017.5.27: 这类离线线段树问题更好的理解方式是数形结合, 把区间[l,r]看成平面上的点(l,r). 见[spoj GSS2] Can you answer these queries II .

考场上写带奇怪记忆化的O(n^2)分治获得20分,也是两天的省选中唯一的20分……

这几天碰到有关“所有子区间”的题,想到这个。以为想到正解,结果是错的……毛爷爷教导我们,Think twice, code once. 做起来不容易。有时代码写到那一行,才察觉先前的想法不够周密或完全错误。

据说莫队也可以做,目前只学了线段树做法。

感谢RicardoWang同学,他的文章写得非常清晰:【BZOJ4540】【Hnoi2016】序列 线段树

待求的式子:

ans=i=lrj=irminA[i..j]

令j=1, 2, …, N,对于每个点i≤j,动态维护j时刻的值:

vali(j)=minA[i..j]

则:

ans=i=lrj=irvali(j)

问题转化为求[l, r]内所有点截止r时刻的历史版本和。

至此,我们设计出这样一个算法:
1. 离线,将询问按照右端点排序。
2. 用单调栈求以每个点作为最小值的区间的左端点。
3. 从左到右扫描整个序列,用线段树维护左边所有点的val及val的历史版本和;如果有询问的右端点等于当前指针,回答。
4. 输出答案。

新加入一个数只会影响包含它的一个连续区间内的val,这为我们提供了便利。

用线段树维护是难点。

询问的是区间内所有点历史版本和之和,记为his。修改,改的是val,记区间val之和为sum。不妨把修改val和更新历史版本分为两步:modify和new_version。这两个操作都得打标记,怎样更新信息和合并标记呢?

观察到这两个操作只涉及加和乘,不涉及取max之类的。记区间长度为len,新的值为v。那么,对于modify:
sum' = v*len

对于new_version:
his' = sum + his

不妨写成矩阵乘法形式:

[lensumhis]100ab0cd1=[lensumhis]

modify的参数为a=v,b=c=d=0,new_version的参数为a=b=c=0,d=1,单位元为a=c=d=0,b=1。

his的更新用到sum,这个sum是变换前的sum。

由于矩阵乘法有结合律,标记的合并是很简便的。

几种标记的c均等于0,为什么还要维护?为什么第一列和第三行就不用维护?

显然第一列是不变的,因为它对应区间长度。第三行的意义是取第三行作为新的第三行(第一列也可以这样分析)。c是行向量1和列向量3的内积,是会变的。

实现的时候可以把tag写成一个类,非常清爽~

这道题带来两个收获:
1. 把和式写成区间查询的形式,看看我们该维护什么。
2. 可以把信息叠加到一个点上,看作一个个历史版本。就像可持久化线段树求区间第k大,把前缀区间的信息依次加入,看作一棵线段树在不同时刻的形态。
3. 在线段树等数据结构上打lazy tag,根本含义是,子树(不包括根)依次执行这些操作得到“实时”信息。如果标记易于合并,并且结合标记能快速计算出信息,即可放心使用。标记的合并就是运算的复合,它不一定是一个数,还可以是矩阵、函数——一切能表示“变换”的东西。

#include <cstdio>#include <algorithm>#define ALL 1, 1, ntypedef long long ll;const int MAX_N = 1e5, MAX_Q = 1e5;const ll inf = 1LL<<60;struct Query {    int k, l, r;    bool operator<(const Query& rhs) const    {        return r < rhs.r;    }} Q[MAX_Q];int n, b[MAX_N+1];ll a[MAX_N+1] = {-inf}, ans[MAX_Q];struct Tag {    ll a, b, c, d;    Tag(ll a=0, ll b=1, ll c=0, ll d=0): a(a), b(b), c(c), d(d) {}    void merge(const Tag& y)    {        Tag x(y.a + a*y.b, b*y.b, c + y.c + a*y.d, b*y.d + d);        *this = x;    }    void clean()    {        a = c = d = 0;        b = 1;    }};struct Segment_Tree {    Tag t[MAX_N*4];    ll sum[MAX_N*4], his[MAX_N*4];    void down(int len, const Tag& x, int o)    {        his[o] += x.c*len + x.d*sum[o];        sum[o] = x.a*len + x.b*sum[o];        t[o].merge(x);    }    void down(int o, int l, int r)    {        int m = (l+r)/2;        down(m-l+1, t[o], o*2);        down(r-m, t[o], o*2+1);        t[o].clean();    }    void up(int o)    {        int lc = o*2, rc = o*2+1;        sum[o] = sum[lc] + sum[rc];        his[o] = his[lc] + his[rc];    }    void update(int L, int R, const Tag& x, int o, int l, int r)    {        if (L<=l && r<=R) {            down(r-l+1, x, o);            return;        }        down(o, l, r);        int m = (l+r)/2;        if (L <= m) update(L, R, x, o*2, l, m);        if (R > m) update(L, R, x, o*2+1, m+1, r);        up(o);    }    ll query(int L, int R, int o, int l, int r)    {        if (L<=l && r<=R)            return his[o];        down(o, l, r);        int m = (l+r)/2;        ll ret = 0;        if (L <= m) ret = query(L, R, o*2, l, m);        if (R > m) ret += query(L, R, o*2+1, m+1, r);        return ret;    }    void modify(int L, int R, ll v)    {        update(L, R, Tag(v, 0, 0, 0), ALL);    }    void new_version(int L, int R)    {        update(L, R, Tag(0, 1, 0, 1), ALL);    }} T;int main(){    int q;    scanf("%d %d", &n, &q);    for (int i = 1; i <= n; ++i)        scanf("%lld", &a[i]);    for (int i = 0; i < q; ++i) {        scanf("%d %d", &Q[i].l, &Q[i].r);        Q[i].k = i;    }    std::sort(Q, Q+q);    static ll S[MAX_N];    int top = 0;    for (int i = n; i >= 0; --i) {        while (top && a[i] <= a[S[top-1]])            b[S[--top]] = i+1;        S[top++] = i;    }    for (int i = 1, j = 0; j < q && i <= n; ++i) {        T.modify(b[i], i, a[i]);        T.new_version(1, i);        while (j < q && Q[j].r == i) {            ans[Q[j].k] = T.query(Q[j].l, Q[j].r, ALL);            ++j;        }    }    for (int i = 0; i < q; ++i)        printf("%lld\n", ans[i]);    return 0;}
0 0