琐碎的区间(线段树区间更新 + 技巧!)

来源:互联网 发布:gre单词书推荐 知乎 编辑:程序博客网 时间:2024/06/05 07:16

琐碎的区间

时间限制: 4 Sec  内存限制: 256 MB
提交: 131  解决: 26
[提交][状态][讨论版]

题目描述

给出一个长度为 n 的整数序列 A[1..n],有三种操作: 
1 l r x :  把[l, r]区间的每个数都加上 x 
2 l r :  把 [l, r] 区间内每个 A[i] 变为 sqrt(A[i]) 的整数部分
3 l r :  求[l, r]  区间所有数的和 
其中 l 和 r 和 x 都代表一个整数 

输入

第一行一个 T,表示数据组数。 
对于每组数据 
Line1:两个数 n m,表示整数序列长度和操作数 
Line2:n 个数,表示 A[1..n] 
Line3…Line3+m-1:每行一个询问,对于第三种询问,请输出答案。 
对于每一种询问,先给出操作的编号,再给出相应的操作,编号与题目描述对应。 
数据约定: 
1<=T<=5 
n,m <= 100000 
1<= A[i], x<=100000 

输出

对于第三种询问,输出答案。每个答案占一行。 

样例输入

1
5 5
1 2 3 4 5
1 3 5 2
2 1 4
3 2 4
2 3 5
3 1 5

样例输出

5
6

提示

来源

[提交][状态]

思路:用线段树维护区间和,同时再记录每个区间的最大值和最小值。开方时,如果某个区间内所有数都相同(或者最大值比最小值大 1 且最大值恰好是完全平方数),那么开方后区间内每个数减少的量都相同,可以直接当作一次区间加法打懒标记,不必再递归到叶子。
// Segment tree supporting three range operations on A[1..n]:
//   1 l r x : add x to every element of [l, r]
//   2 l r   : replace every element of [l, r] by floor(sqrt(element))
//   3 l r   : print the sum of [l, r]
// Each node keeps the sum, max and min of its range plus a pending add tag.
// A range-sqrt on a fully covered node is applied lazily (as a range add)
// whenever every element in the node shrinks by the same amount; otherwise
// it recurses toward the leaves.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>

typedef long long ll;

const int maxn = 1e5 + 10;

// Child indices of node rt (parenthesized so they expand safely anywhere).
#define ls (rt << 1)
#define rs (rt << 1 | 1)

int n, m, t, a[maxn];
// Elements can grow to roughly 1e5 + m * 1e5 (~1e10), so node data is 64-bit.
ll tag[maxn << 2], ma[maxn << 2], mi[maxn << 2], sum[maxn << 2];

const int BufferSize = 1 << 16;
char buffer[BufferSize], *head, *tail;

// Buffered replacement for getchar(). Returns the next byte as an
// unsigned char value, or EOF once stdin is exhausted.
inline int Getchar() {
    if (head == tail) {
        int l = (int)std::fread(buffer, 1, BufferSize, stdin);
        if (l <= 0) return EOF;  // do not re-serve stale buffer bytes
        tail = (head = buffer) + l;
    }
    return (unsigned char)*head++;
}

// Reads one optionally negative decimal integer; returns 0 if EOF is hit
// before any digit (instead of looping forever).
inline int read() {
    int x = 0, f = 1;
    int c = Getchar();
    for (; c != EOF && !std::isdigit(c); c = Getchar())
        if (c == '-') f = -1;
    for (; c != EOF && std::isdigit(c); c = Getchar())
        x = x * 10 + (c - '0');
    return x * f;
}

// Exact floor(sqrt(v)) for v >= 0. The double estimate is corrected with
// integer arithmetic, so the result never depends on floating-point rounding
// (valid for the magnitudes this problem produces, far below 2^62).
inline ll isqrt(ll v) {
    ll s = (ll)std::sqrt((double)v);
    while (s > 0 && s * s > v) --s;
    while ((s + 1) * (s + 1) <= v) ++s;
    return s;
}

// Adds v to every element of node rt, which covers len elements, and records
// it as pending for the node's children.
inline void addAll(int rt, ll v, int len) {
    sum[rt] += v * len;
    ma[rt] += v;
    mi[rt] += v;
    tag[rt] += v;
}

// Recomputes node rt (range [l, r]) from its two children.
void pup(int l, int r, int rt) {
    sum[rt] = sum[ls] + sum[rs];
    ma[rt] = std::max(ma[ls], ma[rs]);
    mi[rt] = std::min(mi[ls], mi[rs]);
    tag[rt] = 0;
}

// Pushes the pending add of node rt (range [l, r]) down to its children.
void pdw(int l, int r, int rt) {
    int mid = (l + r) >> 1;
    addAll(ls, tag[rt], mid - l + 1);
    addAll(rs, tag[rt], r - mid);
    tag[rt] = 0;
}

// Builds node rt over a[l..r].
void build(int l, int r, int rt) {
    tag[rt] = 0;
    if (l == r) {
        ma[rt] = mi[rt] = sum[rt] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, ls);
    build(mid + 1, r, rs);
    pup(l, r, rt);
}

// Adds v to every element of [L, R]; node rt covers [l, r].
void upd(int L, int R, ll v, int l, int r, int rt) {
    if (L <= l && r <= R) {
        addAll(rt, v, r - l + 1);
        return;
    }
    if (tag[rt]) pdw(l, r, rt);
    int mid = (l + r) >> 1;
    if (L <= mid) upd(L, R, v, l, mid, ls);
    if (R > mid) upd(L, R, v, mid + 1, r, rs);
    pup(l, r, rt);
}

// Replaces every element of [L, R] by floor(sqrt(element)); node rt covers [l, r].
void qsqrt(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) {
        // x - floor(sqrt(x)) is non-decreasing in x. If it has the same value
        // at the node's max and min, every element in between shrinks by
        // exactly that amount, so the node can be updated as a lazy range add.
        // This holds when max == min, or when max == min + 1 and max is a
        // perfect square. Leaves always satisfy max == min.
        ll drop = ma[rt] - isqrt(ma[rt]);
        if (drop == mi[rt] - isqrt(mi[rt])) {
            addAll(rt, -drop, r - l + 1);
            return;
        }
    }
    if (tag[rt]) pdw(l, r, rt);
    int mid = (l + r) >> 1;
    if (L <= mid) qsqrt(L, R, l, mid, ls);
    if (R > mid) qsqrt(L, R, mid + 1, r, rs);
    pup(l, r, rt);
}

// Returns the sum of [L, R]; node rt covers [l, r].
ll gao(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return sum[rt];
    if (tag[rt]) pdw(l, r, rt);
    int mid = (l + r) >> 1;
    ll ret = 0;
    if (L <= mid) ret += gao(L, R, l, mid, ls);
    if (R > mid) ret += gao(L, R, mid + 1, r, rs);
    return ret;
}

int main() {
    t = read();
    while (t--) {
        n = read();
        m = read();
        for (int i = 1; i <= n; i++) a[i] = read();
        build(1, n, 1);
        while (m--) {
            int b = read(), c = read(), d = read();
            if (b == 1) {
                int e = read();
                upd(c, d, e, 1, n, 1);
            } else if (b == 2) {
                qsqrt(c, d, 1, n, 1);
            } else {
                std::printf("%lld\n", gao(c, d, 1, n, 1));
            }
        }
    }
    return 0;
}

PS超时的代码:
// Node-struct variant of the range-add / range-sqrt / range-sum segment tree.
// This version was originally posted as exceeding the time limit. In sq(),
// the "mx == mn + 1" lazy branch updated the node but did not return. It then
// pushed that change down and square-rooted the children a second time, which
// gave wrong sums, and the shortcut never pruned the recursion. The branch now
// returns, like the "mx == mn" branch.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>

const int BufferSize = 1 << 16;
char buffer[BufferSize], *head, *tail;

// Buffered replacement for getchar(); returns EOF once stdin is exhausted.
inline int Getchar() {
    if (head == tail) {
        int l = (int)std::fread(buffer, 1, BufferSize, stdin);
        if (l <= 0) return EOF;  // do not re-serve stale buffer bytes
        tail = (head = buffer) + l;
    }
    return (unsigned char)*head++;
}

// Reads one optionally negative decimal integer; returns 0 at EOF.
inline int read() {
    int x = 0, f = 1;
    int c = Getchar();
    for (; c != EOF && !std::isdigit(c); c = Getchar())
        if (c == '-') f = -1;
    for (; c != EOF && std::isdigit(c); c = Getchar())
        x = x * 10 + (c - '0');
    return x * f;
}

typedef long long LL;
#define L(root) ((root) << 1)
#define R(root) (((root) << 1) | 1)

const int MAXN = 1e5 + 5;
int numbers[MAXN];

// One tree node: covered range [left, right], pending add, and aggregates.
struct Node {
    int left, right;
    LL delay;
    LL sum;
    LL mx, mn;
    int mid() { return left + ((right - left) >> 1); }
} tree[MAXN * 4];

// Exact floor(sqrt(v)) for v >= 0, corrected with integer arithmetic.
inline LL isqrt(LL v) {
    LL s = (LL)std::sqrt((double)v);
    while (s > 0 && s * s > v) --s;
    while ((s + 1) * (s + 1) <= v) ++s;
    return s;
}

// Adds `add` to every element of node root and records it as pending.
void applyAdd(int root, LL add) {
    Node &nd = tree[root];
    nd.delay += add;
    nd.sum += add * (nd.right - nd.left + 1);
    nd.mx += add;
    nd.mn += add;
}

// Recomputes node root from its children.
void pushUp(int root) {
    tree[root].sum = tree[L(root)].sum + tree[R(root)].sum;
    tree[root].mx = std::max(tree[L(root)].mx, tree[R(root)].mx);
    tree[root].mn = std::min(tree[L(root)].mn, tree[R(root)].mn);
    tree[root].delay = 0;
}

// Pushes the pending add of node root (range [l, r]) to its children.
void pushDown(int root, int l, int r) {
    (void)l;
    (void)r;  // child ranges are stored in the nodes themselves
    applyAdd(L(root), tree[root].delay);
    applyAdd(R(root), tree[root].delay);
    tree[root].delay = 0;
}

// Builds node root over numbers[left..right].
void build(int root, int left, int right) {
    tree[root].left = left;
    tree[root].right = right;
    tree[root].delay = 0;
    if (left == right) {
        tree[root].sum = tree[root].mx = tree[root].mn = numbers[left];
        return;
    }
    int mid = tree[root].mid();
    build(L(root), left, mid);
    build(R(root), mid + 1, right);
    pushUp(root);
}

// Sum of [left, right], which must lie inside node root's range.
LL query(int root, int left, int right) {
    if (tree[root].left == left && tree[root].right == right) return tree[root].sum;
    if (tree[root].delay) pushDown(root, tree[root].left, tree[root].right);
    int mid = tree[root].mid();
    if (right <= mid) return query(L(root), left, right);
    if (left > mid) return query(R(root), left, right);
    return query(L(root), left, mid) + query(R(root), mid + 1, right);
}

// Adds `add` to every element of [left, right].
void update(int root, int left, int right, LL add) {
    if (tree[root].left == left && tree[root].right == right) {
        applyAdd(root, add);
        return;
    }
    if (tree[root].delay) pushDown(root, tree[root].left, tree[root].right);
    int mid = tree[root].mid();
    if (right <= mid) {
        update(L(root), left, right, add);
    } else if (left > mid) {
        update(R(root), left, right, add);
    } else {
        update(L(root), left, mid, add);
        update(R(root), mid + 1, right, add);
    }
    pushUp(root);
}

// Replaces every element of [left, right] by floor(sqrt(element)).
void sq(int root, int left, int right) {
    if (tree[root].left == left && tree[root].right == right) {
        // x - floor(sqrt(x)) is non-decreasing, so equal drops at max and min
        // mean the whole node shrinks uniformly: apply it lazily and stop.
        // This covers both mx == mn and (mx == mn + 1, mx a perfect square).
        LL drop = tree[root].mx - isqrt(tree[root].mx);
        if (drop == tree[root].mn - isqrt(tree[root].mn)) {
            applyAdd(root, -drop);
            return;  // the missing return here caused the double sqrt / TLE
        }
    }
    if (tree[root].delay) pushDown(root, tree[root].left, tree[root].right);
    int mid = tree[root].mid();
    if (right <= mid) {
        sq(L(root), left, right);
    } else if (left > mid) {
        sq(R(root), left, right);
    } else {
        sq(L(root), left, mid);
        sq(R(root), mid + 1, right);
    }
    pushUp(root);
}

int main() {
    int t = read();
    while (t--) {
        int n = read(), m = read();
        for (int i = 1; i <= n; ++i) numbers[i] = read();
        build(1, 1, n);
        for (int i = 0; i < m; ++i) {
            int op = read();
            int l = read(), r = read();
            if (op == 1) {
                int x = read();
                update(1, l, r, x);
            } else if (op == 2) {
                sq(1, l, r);
            } else {
                std::printf("%lld\n", query(1, l, r));
            }
        }
    }
    return 0;
}

// Equivalent driver using scanf instead of the buffered reader (not called).
int main2() {
    int t, n, m, op, l, r, x;
    std::scanf("%d", &t);
    while (t--) {
        std::scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; ++i) std::scanf("%d", &numbers[i]);
        build(1, 1, n);
        for (int i = 0; i < m; ++i) {
            std::scanf("%d", &op);
            if (op == 1) {
                std::scanf("%d%d%d", &l, &r, &x);
                update(1, l, r, x);
            } else if (op == 2) {
                std::scanf("%d%d", &l, &r);
                sq(1, l, r);
            } else {
                std::scanf("%d%d", &l, &r);
                std::printf("%lld\n", query(1, l, r));
            }
        }
    }
    return 0;
}




0 0
原创粉丝点击