rmq问题

问题定义

RMQ(Range Maximum/Minimun Query)问题指给定一个数组A[n],多次问询区间最大/小值。

ST表做法

如果数组给定后不会再修改,可以采用ST表来做,设f[i][j]表示下标从i开始的连续2^j个元素的最值,那么显然f[i][0]=A[i],转移方程:f[i][j] = M{f[i][j-1], f[i+(1«j-1)][j-1]},由此可预处理出所有可用的f[i][j],时间复杂度为O(nlogn);然后对于每次查询(l,r),将它分割成两个可能有重叠的长度相等的区间,利用预处理出的结果在O(1)时间内给出结果。

int n, q, l, r, k, f[10005][20];
int main() {
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++)
        scanf("%d", &f[i][0]);
    for (int j = 1; j < 20; j++)
        for (int i = 1; i + (1<<j) - 1 <= n; i++)
            f[i][j] = max(f[i][j-1], f[i+(1<<j-1)][j-1]);
    while (q--) {
        scanf("%d%d", &l, &r);
        for (k = 0; 1<<k+1 <= r-l; k++);
        printf("%d\n", max(f[l][k], f[r+1-(1<<k)][k]));
    }
    return 0;
}

线段树做法

如果在查询期间数组仍会被修改,可考虑用线段树,预处理建树时间为O(nlogn),后续每次更新或查询时间复杂度为O(logn)。

int f[40005];
void build(int x, int l, int r) {
    if (l == r) {
        scanf("%d", &f[x]);
        return;
    }
    int mid = (l + r) / 2;
    build(2*x, l, mid);
    build(2*x+1, mid+1, r);
    f[x] = max(f[2*x], f[2*x+1]);
}
int query(int x, int l, int r, int L, int R) {
    if (l == L && r == R)
        return f[x];
    int mid = (l + r) / 2;
    if (R <= mid) return query(2*x, l, mid, L, R);
    if (L > mid) return query(2*x+1, mid+1, r, L, R);
    return max(query(2*x, l, mid, L, mid), query(2*x+1, mid+1, r, mid+1, R));
}
int main() {
    int n, q, l, r;
    scanf("%d%d", &n, &q);
    build(1, 1, n);
    while (q--) {
        int l, r;
        scanf("%d%d", &l, &r);
        printf("%d\n", query(1, 1, n, l, r));
    }
    return 0;
}

问题扩展

用ST表方法解RMQ问题,如果要求的不是最值而是取最值元素对应的编号,如何处理?

有两种做法:一种是仍然预处理出最值的ST表,然后二分答案,比较左右区间谁的最值与整个区间的最值相同,不断缩小范围直到区间长度为1,处理每次询问的时间复杂度为O(logn);另一种做法是预处理出取最值的下标,然后O(1)回答每次询问。

这里给出第2种做法的示例代码,注意如果如果区间内有多个位置同时都为最小值,那么返回位置最小的那个。

int n, m, a[100005], p[100005][20], x, y;
int rmq(int l, int r) {
    int k;
    for (k = 0; 1 << (k + 1) <= r - l; k++);
    if (a[p[l][k]] <= a[p[r+1-(1<<k)][k]])
        return p[l][k];
    return p[r+1-(1<<k)][k];
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    for (int i = 1; i <= n; i++)
        p[i][0] = i;
    for (int j = 1; j < 20; j++) {
        for (int i = 1; i + (1 << j) - 1 <= n; i++) {
            if (a[p[i][j-1]] <= a[p[i+(1<<(j-1))][j-1]])
                p[i][j] = p[i][j-1];
            else
                p[i][j] = p[i+(1<<(j-1))][j-1];
        }
    }
    for (int i = 0; i < m; i++) {
        scanf("%d%d", &x, &y);
        printf("%d\n", rmq(x, y));
    }
}
Table of Contents