BZOJ 4505 K个串 主席树标记永久化

来源:互联网 发布:淘宝美版ac68u 编辑:程序博客网 时间:2024/06/15 17:17

Description

    兔子们在玩 k 个串的游戏。首先,它们拿出了一个长度为 n 的数字序列,选出其中的一个连续子串,然后统计其子串中所有数字之和(注意这里重复出现的数字只被统计一次)。
    兔子们想知道,在这个数字序列所有连续的子串中,按照以上方式统计其所有数字之和,第 k 大的和是多少。

Input

第一行,两个整数 nk,分别表示长度为 n 的数字序列和想要统计的第k大的和
接下里一行 n 个数 ai,表示这个数字序列

Output

一行一个整数,表示第k大的和

Sample Input

7 5
3 -2 1 2 2 1 3 -2

Sample Output

4

HINT

1n100000, 1k200000, 0|ai|109 数据保证存在第 k 大的和



Solution:

    初看这道题我也很懵逼啊,始终在 n2 的圈中跳不出来,后来实在是难受看了一下题解,咦,听说有道类似的在 BZOJ 上题叫超级钢琴,做了一做有些思路了,但是这个区间的最值还是很恶心,实在是不行,再次瞄了一眼题解,哈,主席树,讲真主席树按区间建树,再区间查找最值没写过,哈,什么,标记永久化,这,没话说。
    首先我们预处理出当前数前一个与之相同的数的位置,这里我们记一个 pre[i] 表示,当前数 num[i] 对答案的贡献只在 pre[num[i]]+1i 位置,对于区间 LR 的和就为其中 pre[num]>L 的所有的 num 的和。
    然后我们考虑建主席树,对于新建的一颗树实在其上颗树的 pre[num[i]]+1i 加上 num[i] , 由于要涉及到区间加,于是我们标记永久化,记一个 Lazy 标记,表示该区间需要加上 Lazy 的大小的值,记得在统计答案的时候要加上 Lazy 标记, 这样我们就可以用主席树来求区间的和。
    这是我们考虑如何求第 K 大的区间,我们记一个优先队列,优先队列中的每一点存五个信息(这种做法类似于超级钢琴) i,pos,l,r,max 分别表示以 i 为右端点, 左端点在区间 lr 时,在 pos 取得最大值,最大值为 max,然后我们考虑更新后面的状态,由于对于当前区间我们已经用过 pos 的信息了,接下来就不需要了,于是我们把区间拆开,拆为 lpos1pos+1r,然后丢入优先队列,更新。




Code :

#include <cstdio>#include <cstdlib>#include <cstring>#include <string>#include <algorithm>#include <iostream>#include <cmath>#include <ctime>#include <map>#include <queue>#define LL long long#define mp make_pairusing namespace std;inline int read() {    int i = 0, f = 1;    char ch = getchar();    while(!isdigit(ch)) {        if(ch == '-') f = -1; ch = getchar();    }    while(isdigit(ch)) {        i = (i << 3) + (i << 1) + ch - '0'; ch = getchar();    }    return i * f;}const int MAXN = 1e5 + 5;struct point {    int pos, l, r, x;    LL sum;    point() {}    point(int ex, int epos, int el, int er, LL esum) : x(ex), pos(epos), l(el), r(er), sum(esum) {}    inline bool operator <(const point & a) const {        return sum < a.sum;    }};struct node {    LL mx, lazy;    node *lc, *rc;    int pos;} *root[MAXN], pool[MAXN * 40], *cur = pool;int num[MAXN];inline void build(int l, int r, node *&rt) {    rt = cur++, rt->pos = l;    if(l == r) return;    register int mid = l + r >> 1;    build(l, mid, rt->lc), build(mid + 1, r, rt->rc);}inline void get(node *a, node *b) {    a->lc = b->lc, a->rc = b->rc, a->lazy = b->lazy;}inline void insert(node *&rt, node *pre, int l, int r, int s, int t, int d) {    rt = cur++, get(rt, pre);    if(s <= l && r <= t) {        rt->mx = pre->mx + d, rt->lazy += d, rt->pos = pre->pos;        return ;    }    register int mid = l + r >> 1;    if(s <= mid) insert(rt->lc, pre->lc, l, mid, s, t, d);    if(mid < t) insert(rt->rc, pre->rc, mid + 1, r, s, t, d);    if(rt->lc->mx >= rt->rc->mx) {        rt->mx = rt->lc->mx + rt->lazy, rt->pos = rt->lc->pos;    } else {        rt->mx = rt->rc->mx + rt->lazy, rt->pos = rt->rc->pos;    }}inline pair<int, LL> query(node *rt, int l, int r, int s, int t) {    if(s <= l && r <= t) return mp(rt->pos, rt->mx);    register int mid = l + r >> 1;    if(t <= mid) {        pair<int, LL> tmp = query(rt->lc, l, mid, s, t);        tmp.second += rt->lazy;        return tmp;    } else if(s > mid) {        pair<int, LL> tmp = query(rt->rc, mid + 1, r, s, t);        tmp.second += rt->lazy;        return tmp;    } else {        pair<int, LL> tmpl, tmpr;        tmpl = query(rt->lc, l, mid, s, t); tmpr = query(rt->rc, mid + 1, r, s, t);        tmpl.second += rt->lazy, tmpr.second += rt->lazy;        return tmpl.second >= tmpr.second ? tmpl : tmpr;    }}inline void solve() {    int n = read(), k = read();    build(1, n, root[0]);    priority_queue<point> q;    map<int, int> pre;    for(register int i = 1; i <= n; ++i) {        num[i] = read();        int s = pre[num[i]] + 1, t = i, d = num[i];        insert(root[i], root[i - 1], 1, n, s, t, d);        pair<int, LL> tmp = query(root[i], 1, n, 1, t);        q.push(point(i, tmp.first, 1, i, tmp.second));        pre[num[i]] = i;    }    while(--k) {        point a = q.top();        q.pop();        if(a.l < a.pos) {            pair<int, LL> tmp = query(root[a.x], 1, n, a.l, a.pos - 1);            q.push(point(a.x, tmp.first, a.l, a.pos - 1, tmp.second));        }        if(a.r > a.pos) {            pair<int, LL> tmp = query(root[a.x], 1, n, a.pos + 1, a.r);            q.push(point(a.x, tmp.first, a.pos + 1, a.r, tmp.second));        }    }    cout<<q.top().sum;}int main() {    solve();}
原创粉丝点击