二叉搜索树bst

二叉搜索树BST是一种排序搜索树,每个节点的值大于左子节点而小于右子节点,支持插入、查找、删除、最小、最大、前驱、后继等操作,操作的时间复杂度与树的高度相关。

考虑到BST是其他平衡树的基础,这里对其进行简单地整理。

节点定义

struct Node {
    int key, cnt;
    Node *L, *R, *P;
    Node(int key, Node *p):key(key),cnt(1),P(p) {
        L = R = nullptr;
    }
};

其中cnt代填键值为key的记录数,因而键值允许重复。

根据键查找节点

Node* Find(Node *x, int key) {
    if (!x) return x;
    if (key == x->key) return x;
    return Find(key < x->key ? x->L : x->R, key);
}

获取子树最值节点

Node* Min(Node *x) {
    while (x && x->L)
        x = x->L;
    return x;
}
Node* Max(Node *x) {
    while (x && x->R)
        x = x->R;
    return x;
}

求前驱和后继

Node* Prev(Node *x) {
    if (!x) return x;
    if (x->L) return Max(x->L);
    Node *y = x->P;
    while (y && x == y->L)
        x = y, y = x->P;
    return y;
}
Node* Next(Node *x) {
    if (!x) return x;
    if (x->R) return Min(x->R);
    Node *y = x->P;
    while (y && x == y->R)
        x = y, y = x->P;
    return y;
}

插入节点

Node* Insert(Node *&x, Node *p, int key) {
    if (!x) return x = new Node(key, p);
    if (key == x->key) {
        //x->cnt += 1;
        return x;
    }
    return Insert(key < x->key ? x->L : x->R, x, key);
}

删除节点

Node* Delete(Node *&x, int key) {
    if (!x) return x;
    if (key != x->key)
        return Delete(key < x->key ? x->L : x->R, key);
    if (x->cnt > 1) {
        x->cnt -= 1;
        return x;
    }
    if (!x->L) {
        Node *y = x->P;
        if (x->R) x->R->P = y;
        if (y) {
            if (x == y->L)
                y->L = x->R;
            else
                y->R = x->R;
        } else x = x->R;
        return x;
    } else if (!x->R) {
        Node *y = x->P;
        if (x->L) x->L->P = y;
        if (y) {
            if (x == y->L)
                y->L = x->L;
            else
                y->R = x->L;
        } else x = x->L;
        return x;
    } else {
        Node *t = Next(x);
        x->key = t->key;
        x->cnt = t->cnt;
        return Delete(t, t->key);
    }
}

二分搜索

void LowerBound(Node *x, int key, int &ans) {
    if (!x) return;
    if (x->key >= key) ans = min(ans, x->key);
    LowerBound(x->key > key ? x->L : x->R, key, ans);
}
void UpperBound(Node *x, int key, int &ans) {
    if (!x) return;
    if (x->key > key) ans = min(ans, x->key);
    UpperBound(x->Key > key ? x->L : x->R, key, ans);
}
Table of Contents