JZOJ 5432 三元组

来源:互联网 发布:java 成功生成文件夹 编辑:程序博客网 时间:2024/05/29 04:09

三元组

来自GJX大佬的题目

Description

给出n个三元组(xi,yi,zi)及三个整数X,Y,Z满足X+Y+Z=n
每一个三元组最多只能选择一个数(即xi,yi,zi中的一个)
给出三个要求:
1、选择xi的三元组恰有X个。
2、选择yi的三元组恰有Y个。
3、选择zi的三元组恰有Z个。
问选出的数的和最大为多少。

Data Constraint

1<=n<=51050<=xi,yi,zi<=5105

Solution

考虑若X=0时怎么做,先强制选择所有的yi,然后按照zi-yi排序,选择最大的Z个加入答案中。
因而对于三元组,可以先强制选择所有的xi,然后从n个新的二元组(yi-xi,zi-xi)中选择YyixiZzixi使价值最大。(有X个二元组不选择)

设新的二元组为(vi,ui)。
n个二元组按照vi为第一关键字从大到小排序。
考虑枚举最后选择vi的位置为i,则1~i的位置一定都被选择了,如果不是,则最后选择vi的位置一定在i前面,significantly
使用同样的方法,将1~ivi全部都强制选择,然后从前i个二元组中选择i-Y个最大的ui-vi,从i+1~n个二元组中选择Z-(i-Y)个最大的ui,两块价值加起来与待定答案取max即可。

对于维护一些数中的前k大的和,看到数据范围0<=xi,yi,zi<=5105,可以开两个桶维护区间1~ii+1~n,维护两个指针分别指向第k大的位置,然后可以发现随着i的增加,指针只会往单方向移动,因此时间复杂度可以做到O(n)。

Code

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>#define fo(i,j,l) for(int i=j;i<=l;i++)#define fd(i,j,l) for(int i=j;i>=l;i--)using namespace std;typedef long long ll;const ll N=55e4,M=2*N;struct note{    ll y,z;}t[N];int X,Y,Z;ll x[N],t1[3*N],t2[3*N];int n,m,j,k,l,i,o,p;int max(int a,int b){if(a>b)return a;else return b;}int read(){    int o=0; char ch=' ';    for(;ch<'0'||ch>'9';)ch=getchar();    for(;ch>='0'&&ch<='9';ch=getchar())o=o*10+ch-48;    return o;}bool kmp(note a,note b){return a.y!=b.y ? a.y>b.y : a.z<b.z ;}int main(){    cin>>X>>Y>>Z; n=X+Y+Z;    fo(i,1,n)x[i]=read(),t[i].y=read(),t[i].z=read();    ll ans=0,da=0;    fo(i,1,n)da+=x[i],t[i].y-=x[i],t[i].z-=x[i];    sort(t+1,t+n+1,kmp);    ll dq=0,op=0,u=0,zd=M+N-10,v=0;    fo(i,1,Y)t1[t[i].z-t[i].y+M]++,dq+=t[i].y;    fo(i,Y+1,n)t2[t[i].z+M]++;    for(;!t1[zd];)zd--;    v=t1[zd]; t1[zd]=0;    int rr=M+N-10,sx=Z;    while(t2[rr]<sx)op+=t2[rr]*(rr-M),sx-=t2[rr],rr--;    u=rr; op+=sx*(rr-M);  t2[rr]=0; ans=op+dq;    fo(i,Y+1,Y+Z){        dq+=t[i].y; ll tt=t[i].z-t[i].y+M;         if(tt>zd)dq+=tt-M;        else if(tt==zd)dq+=tt-M;        else {            t1[tt]++;            dq+=zd-M; v--;            if(!v){                for(;!t1[zd];)zd--;                v=t1[zd]; t1[zd]=0;            }        }        tt=t[i].z+M;        if(tt>u)op-=tt-M,t2[tt]--;        else {            sx--,op-=u-M;            if(!sx){                for(;!t2[u];)u++;                sx=t2[u]; t2[u]=0;            }        }        ans=max(ans,dq+op);    }    printf("%lld",ans+da);}