斜率优化dp小结

来源:互联网 发布:ssh 知乎 编辑:程序博客网 时间:2024/06/08 08:26

最近刷了几道斜率优化dp算是对斜率优化有了一定的了解了

现在来小结一下

斜率优化dp的优化能力是将n^2优化成n,n^3优化成n^2

其转移方程一般是dp[k]=min(dp[i]+cost[i+1][k]),斜率优化优化的是排除一些不可能是最优解的解,那么什么情况下不可能是最优解呢



下面看一道题HDU 2829

状态转移方程是dp[i][j]=min(dp[k][j-1]+cost[k+1][i])

那么k2比k1更优的情况是

dp[k1][j-1]+cost[k1+1][i]>dp[k2][j-1]+cost[k2+1][i]

我们发现cost[k+1][i]=cost[1][i]-cost[1][k]-sum[k]*(sum[i]-sum[k])

那么我们带进去化简就可以得到

dp[k1][j-1]-cost[1][k1]+sum[k1]*sum[k1]-sum[k1]*sum[i]>dp[k2][j-1]-cost[1][k2]+sum[k2]*sum[k2]-sum[k2]*sum[i]

那么我们设

y1=dp[k1][j-1]-cost[1][k1]+sum[k1]*sum[k1]

y2=dp[k2][j-1]-cost[1][k2]+sum[k2]*sum[k2]

x1=sum[k1],x2=sum[k2]


所以k2比k1更优的条件是

y2-y1<(x2-x1)*sum[i]

如果等于那就是一样优,那就也是可以删掉k1的那么我们可以在条件上加一个=号

也就是(y2-y1)<=(x2-x1)*sum[i]


假设我有三个数k1,k2,k3,  k1<k2<k3

那么k2永远都不会是最优的情况是

(y3-y2)/(x3-x2)<(y2-y1)/(x2-x1)

我们分类讨论一下如果对于sum[i]来说

如果(y3-y2)/(x3-x2)<=sum[i]那么k3比k2优或者一样优,那么都是可以删掉的

如果(y3-y2)/(x3-x2)>sum[i],此时(y2-y1)/(x2-x1)>sum[i]那么k1比k2优



#include <iostream>#include <cstring>#include <cstdio>#include <algorithm>#define maxn 1500using namespace std;int sum[maxn],dp[maxn][maxn],cost[maxn];int que[maxn];int DP(int n,int m){    for (int k=1;k<=n;k++){        dp[k][k-1]=0;        dp[k][0]=cost[k];    }//初始化完成    int head,tail;    for (int j=1;j<=m;j++){        head=tail=0;
//我们优化的是最后一个遍历所以对于每一个i我们都有一个新的que
//对于一个新的j就是说我们分成j段        //我要遍历的第一个k就是dp[j][j-1]        //所以我们把j入队即可
que[tail++]=j; for (int i=j+1;i<=n;i++){ while(head+1<tail){ int q1=que[head],q2=que[head+1]; int y1=dp[q1][j-1]-cost[q1]+sum[q1]*sum[q1]; int y2=dp[q2][j-1]-cost[q2]+sum[q2]*sum[q2]; int x1=sum[q1],x2=sum[q2]; if ((y2-y1)<=sum[i]*(x2-x1)) head++; else break; }
    //利用sum[i]从头到尾找最优
    //下面解释下为什么这个是最优的,可以看到前面的已经是最优了
   //我们发现他的斜率是单调递减的,也就是说此时head到后面任意一点的斜率永远>sum[i],所以head是最优的情况            int k=que[head];            dp[i][j]=dp[k][j-1]+cost[i]-cost[k]-sum[k]*(sum[i]-sum[k]);            while(head+1<tail){                int q1=que[tail-2],q2=que[tail-1],q3=i;                int y3=dp[q3][j-1]-cost[q3]+sum[q3]*sum[q3];                int y2=dp[q2][j-1]-cost[q2]+sum[q2]*sum[q2];                int y1=dp[q1][j-1]-cost[q1]+sum[q1]*sum[q1];                int x3=sum[q3],x2=sum[q2],x1=sum[q1];                if (((y3-y2)*(x2-x1))<=((y2-y1)*(x3-x2)))  tail--;                else break;            }            que[tail++]=i;
   //因为我们要遍历到的下一个节点是dp[i+1][j],他的k的取值范围0-i所以我们要把i入队
   //在入队的时候我们就要淘汰y2,这样子不仅淘汰了非最优解还维护了队列斜率的单调递减性质        }    }    return dp[n][m];}int main(){    int m,n;    while(~scanf("%d %d",&n,&m)&&(n&&m)){        sum[0]=0;cost[0]=0;        for (int k=1;k<=n;k++){            int save;            scanf("%d",&save);            sum[k]=sum[k-1]+save;            cost[k]=cost[k-1]+sum[k-1]*save;        }        printf("%d\n",DP(n,m));    }    return 0;}