RMQ问题--ST算法(Sparse Table)

来源:互联网 发布:68淘宝小号 编辑:程序博客网 时间:2024/05/16 05:56

RMQ(Range Minimum/Maximum Query)问题是求区间最值问题。


这里介绍的ST算法,虽然预处理的复杂度大了点O(nlogn),但是查询复杂度可以降低到O(1)


首先,预处理的时候,是一个动态规划的过程。假设,给出一个数列1, 3, 7, 2, 4, 9, 0,用数组a[]来存储,下标从1开始,便于处理。mx[i, j]表示的是从下标i开始的长度为1<<j的区间最大值,例如,mx[2, 2]就是闭区间[2, 5](3, 7, 2, 4)的最大值,即为7。mi[i, j]同理。我们很容易发现,mx[1, 0] = 1, mx[2, 0] = 3....即mx[i, 0] = a[i]。于是,我们就可以得到状态以及初始值了。接下来,求状态转移方程,当i加上j的一半,不超过数列长度时,我们就可以mx[i, j] = max(mx[i][1<<(j-1)], mx[i+1<<(j-1)+1][1<<(j-1)]),这里有点二分的思想。


然后,就是求最值了。以最大值为例,这里我们仍旧用到了二分的思想。将区间二分一下,然后,返回两者最大值。区间会出现交叉,但是并不影响解题。


#include <cstdio>#include <iostream>#include <algorithm>#include <cmath>using namespace std;const int maxn = 1000;int n, q;int a[maxn], mx[maxn][maxn], mi[maxn][maxn];void stinit() {    for(int i = 1; i <= n; i++) mx[i][0] = mi[i][0] = a[i];    int p = (int)(log(n*1.0)/log(2.0));    for(int i = 1; i <= p; i++)        for(int j = 1; j <= n; j++) {            mx[j][i] = mx[j][i-1];            mi[j][i] = mi[j][i-1];            if(j+(1<<(i-1)) <= n) {                mx[j][i] = max(mx[j][i], mx[j+(1<<(i-1))][i-1]);                mi[j][i] = min(mi[j][i], mi[j+(1<<(i-1))][i-1]);            }        }}int stmin(int l, int r) {    int p = (int)(log((r-l+1)*1.0)/log(2.0));           return min(mi[l][p], mi[r-(1<<p)+1][p]);    //由于求p向下取整,导致l+p可能达不到r,但是l+p至少能区间[l, r]的一半}int stmax(int l, int r) {    int p = (int)(log((r-l+1)*1.0)/log(2.0));    return max(mx[l][p], mx[r-(1<<p)+1][p]);}int main() {    cin >> n >> q;    for(int i = 1; i <= n; i++) cin >> a[i];    stinit();    while(q--) {        int l, r;        cin >> l >> r;        printf("maxmum number is : %d, minimum number is %d\n", stmax(l, r), stmin(l, r));    }    return 0;}


如果问题求的是最值的下标,那么也很简单,只要把二维数组存的值改为下标即可。

#include <cstdio>#include <iostream>#include <cmath>#include <algorithm>using namespace std;const int maxn = 1000;int n, q;int a[maxn], mx[maxn][maxn], mi[maxn][maxn];void stinit() {    for(int i = 1; i <= n; i++) mx[i][0] = mi[i][0] = i;    int p = (int)(log(n*1.0)/log(2.0));    for(int i = 1; i <= p; i++)        for(int j = 1; j <= n; j++) {            mx[j][i] = mx[j][i-1];            mi[j][i] = mi[j][i-1];            if(j+(1<<(i-1)) <= n) {                if(a[mx[j+(1<<(i-1))][i-1]] > a[mx[j][i]])                    mx[j][i] = mx[j+(1<<(i-1))][i-1];                if(a[mi[j+(1<<(i-1))][i-1]] < a[mi[j][i]])                    mi[j][i] = mi[j+(1<<(i-1))][i-1];            }        }}int stmin(int l, int r) {    int p = (int)(log((r-l+1)*1.0)/log(2.0));    if(a[mi[l][p]] < a[mi[r-(1<<p)+1][p]]) return mi[l][p];    else return mi[r-(1<<p)+1][p];}int stmax(int l, int r) {    int p = (int)(log((r-l+1)*1.0)/log(2.0));    if(a[mi[l][p]] > a[mi[r-(1<<p)+1][p]]) return mx[l][p];    else return mx[r-(1<<p)+1][p];}int main() {    cin >> n >> q;    for(int i = 1; i <= n; i++) cin >> a[i];    stinit();    while(q--) {        int l, r;        cin >> l >> r;        printf("maxmum number's location is : %d, minimum number's location is %d\n", stmax(l, r), stmin(l, r));    }    return 0;}

0 0
原创粉丝点击