hdoj 4747 线段树

来源:互联网 发布:网络配线架b的打法 编辑:程序博客网 时间:2024/05/22 02:02

hdoj 4747

题意:给一个序列,问所有区间[i,j](i<=j)中,区间中最小的非负整数加和是多少。

思路:

看样例:

1 0 2 0 1

做出以第一个元素1为起点的分别到其他各点的所有区间最小值:

0 2 3 3 3

发现是一个单调递增序列,实际上对于任何一个样例都是这样的结果,因为随着数字的增加,最小值只会增大。

然后我们删掉第一个数字,做第二个元素0的区间最小值:

x 1 1 1 3

因为我们删掉了0之前的1后,在下一个1出现之前,这段区间中就可能拥有一个更小的值1。

而相比于1的区间最小值,0的区间最小值只修改了两个1中间的这段序列,并且是将所有大于1的数字修改为1。

而又因为这是个单调序列,所以修改的这段范围最多有1个(如果没有比删掉的这个值小的数就不用修改)

所以题目就转化成了线段树区间修改区间查询。

先预处理相同数字的位置邻接表,再找出序列中最左端的大于等于删除掉的数Vi的位置L,然后找到下个Vi出现的位置R,如果没有就是序列的最后位置n,最后将这个区间的值修改为Vi。

#include <cstdio>#include <cstring>#include <map>#include <algorithm>using namespace std;const int M = 200020;struct Tree{    int l, r, val;    int flag;    long long sum;}tree[M * 4];int seq[M], mex[M], Next[M];bool vis[M];map<int, int>mp;int cnt, n;void buildtree(int rt, int l, int r) {    tree[rt].l = l, tree[rt].r = r, tree[rt].flag = false;    if(l == r) {        tree[rt].val = mex[cnt++];        tree[rt].sum = tree[rt].val;    //    printf("%d %d %d %I64d\n", rt, tree[rt].l, tree[rt].r, tree[rt].sum);        return ;    }    int mid = (l + r) / 2;    buildtree(rt * 2, l, mid);    buildtree(rt * 2 + 1, mid + 1, r);    tree[rt].val = max(tree[rt * 2].val, tree[rt * 2 + 1].val);    tree[rt].sum = tree[rt * 2].sum + tree[rt * 2 + 1].sum; //   printf("%d %d %d %I64d\n", rt, tree[rt].l, tree[rt].r, tree[rt].sum);}void pushdown(int rt){    if(tree[rt].flag) {        tree[rt * 2].val = tree[rt * 2 + 1].val = tree[rt].val;        tree[rt * 2].sum = (long long) (tree[rt * 2].r - tree[rt * 2].l + 1) * tree[rt * 2].val;        tree[rt * 2 + 1].sum = (long long) (tree[rt * 2 + 1].r - tree[rt * 2 + 1].l + 1) * tree[rt * 2 + 1].val;        tree[rt * 2].flag = tree[rt * 2 + 1].flag = true;        tree[rt].flag = false;    }}void pullup(int rt) {    tree[rt].sum = tree[rt * 2].sum + tree[rt * 2 + 1].sum;    tree[rt].val = max(tree[rt * 2].val, tree[rt * 2 + 1].val);}long long query(int rt, int l, int r) {    if(tree[rt].l == l && tree[rt].r == r) {        return tree[rt].sum;    }    pushdown(rt);    int mid = (tree[rt].l + tree[rt].r) / 2;    if(l > mid) return query(rt * 2 + 1, l, r);    else if(r <= mid) return query(rt * 2, l, r);    else return query(rt * 2, l, mid) + query(rt * 2 + 1, mid + 1, r);}void update(int rt, int l, int r, int a) {    if(tree[rt].l == l && r == tree[rt].r) {        tree[rt].val = a;        tree[rt].sum = (long long)(tree[rt].r - tree[rt].l + 1) * a;        tree[rt].flag = true;        return ;    }    pushdown(rt);    int mid = (tree[rt].l + tree[rt].r) / 2;    if(l > mid) update(rt * 2 + 1, l, r, a);    else if(r <= mid) update(rt * 2, l, r, a);    else {        update(rt * 2, l, mid, a);        update(rt * 2 + 1, mid + 1, r, a);    }    pullup(rt);}int getLeft(int rt, int val) {    if(tree[rt].l == tree[rt].r) {        if(tree[rt].val >= val) return tree[rt].l;        else return n + 1;    }    pushdown(rt);    if(tree[rt * 2].val >= val) return getLeft(rt * 2, val);    else return getLeft(rt * 2 + 1, val);}int main() {  //  freopen("in.txt", "r", stdin);    while(~scanf("%d", &n) && n) {        mp.clear();        for(int i = 1; i <= n; i++) scanf("%d", &seq[i]);        for(int i = 1; i <= n; i++) Next[i] = n;        memset(vis, 0, sizeof vis);        int val = 0;        for(int i = 1; i <= n; i++) {            if(seq[i] < M) vis[seq[i]] = true;            while(vis[val]) val++;            mex[i] = val;        }        for(int i = 1; i <= n; i++) {            if(mp.find(seq[i]) != mp.end()) Next[mp[seq[i]]] = i - 1;            mp[seq[i]] = i;        }        long long ans = 0;        cnt = 1;        buildtree(1, 1, n);        for(int i = 1; i <= n; i++) {            ans += query(1, i, n);            int left = getLeft(1, seq[i]);        //    printf("[%d, %d] = %d\n", left, Next[i], seq[i]);            if(left <= Next[i]) update(1, left, Next[i], seq[i]);      //      printf("%I64d\n", ans);        }        printf("%I64d\n", ans);    }    return 0;}


0 0