二叉搜索树 (c++非递归版)

来源:互联网 编辑:程序博客网 时间:2024/06/08 12:19

400多行代码（其实还有好多API没完成），请做好心理准备。计算节点count的部分是递归的（为了不太复杂）。下面也给出了测试代码。

BST.h

#pragma once
#include <iostream>
#include <stack>
#include <stdexcept>
#include <utility>

/// Ordered symbol table (Key -> Value) backed by an unbalanced binary search
/// tree. Every operation is iterative, so deep trees (e.g. built from sorted
/// inserts) cannot overflow the call stack. Keys only need `operator<`; two
/// keys where neither is less than the other are treated as the same entry,
/// and putting such a key again overwrites its value.
///
/// The tree owns its nodes: copies are deep, moves transfer ownership, and
/// the destructor frees everything.
template<typename Key, typename Value>
class BST
{
private:
    /// One tree node; `count` caches the number of nodes in its subtree.
    class Node
    {
    public:
        Node* left = nullptr;
        Node* right = nullptr;
        Key key;
        Value value;
        int count = 0;
    public:
        Node(const Key& k, const Value& v, const int& c) : key(k), value(v), count(c)
        {
        }
    };

    Node* root = nullptr;

    /// Number of nodes in the subtree rooted at r (0 for an empty subtree).
    int size(Node* r) const
    {
        return r == nullptr ? 0 : r->count;
    }

    /// Recomputes the cached count of every node on `path`, deepest first.
    /// Only ancestors of a modified position can be stale, so this costs
    /// O(depth) instead of re-walking the whole tree after each update.
    void fix_counts(std::stack<Node*>& path)
    {
        while (!path.empty())
        {
            Node* n = path.top();
            path.pop();
            n->count = size(n->left) + size(n->right) + 1;
        }
    }

    /// Frees every node of the subtree rooted at r (iterative, no recursion).
    static void destroy(Node* r)
    {
        std::stack<Node*> pending;
        if (r != nullptr)
            pending.push(r);
        while (!pending.empty())
        {
            Node* n = pending.top();
            pending.pop();
            if (n->left != nullptr)
                pending.push(n->left);
            if (n->right != nullptr)
                pending.push(n->right);
            delete n;
        }
    }

    /// Returns a deep copy of the subtree rooted at src (nullptr when empty).
    /// If an allocation or a Key/Value copy throws, the partial copy is freed
    /// before the exception propagates.
    static Node* clone(const Node* src)
    {
        if (src == nullptr)
            return nullptr;
        Node* copy = new Node(src->key, src->value, src->count);
        try
        {
            // Each entry pairs a source node with its already-allocated copy.
            std::stack<std::pair<const Node*, Node*>> pending;
            pending.push({src, copy});
            while (!pending.empty())
            {
                const Node* s = pending.top().first;
                Node* d = pending.top().second;
                pending.pop();
                if (s->left != nullptr)
                {
                    d->left = new Node(s->left->key, s->left->value, s->left->count);
                    pending.push({s->left, d->left});
                }
                if (s->right != nullptr)
                {
                    d->right = new Node(s->right->key, s->right->value, s->right->count);
                    pending.push({s->right, d->right});
                }
            }
        }
        catch (...)
        {
            destroy(copy);
            throw;
        }
        return copy;
    }

    /// Inserts (k, v) into the subtree rooted at r, overwriting the value if
    /// k is already present. Returns the (possibly new) subtree root.
    Node* put(Node* r, const Key& k, const Value& v)
    {
        if (r == nullptr)
            return new Node(k, v, 1);
        std::stack<Node*> path;  // nodes whose subtree gains the new node
        Node* curr = r;
        while (true)
        {
            path.push(curr);
            if (k < curr->key)
            {
                if (curr->left == nullptr)
                {
                    curr->left = new Node(k, v, 1);
                    break;
                }
                curr = curr->left;
            }
            else if (curr->key < k)
            {
                if (curr->right == nullptr)
                {
                    curr->right = new Node(k, v, 1);
                    break;
                }
                curr = curr->right;
            }
            else
            {
                curr->value = v;  // existing key: overwrite, sizes unchanged
                return r;
            }
        }
        fix_counts(path);
        return r;
    }

    /// Returns the value stored under `key` in the subtree rooted at r.
    /// Throws std::out_of_range if the key is absent (or the subtree empty).
    Value get(Node* r, const Key& key) const
    {
        while (r != nullptr)
        {
            if (key < r->key)
                r = r->left;
            else if (r->key < key)
                r = r->right;
            else
                return r->value;
        }
        throw std::out_of_range("can't get the value of key");
    }

    /// Returns the node with the smallest key in the subtree rooted at r.
    /// Throws std::out_of_range if the subtree is empty.
    Node* min(Node* r) const
    {
        if (r == nullptr)
            throw std::out_of_range("can't gain the min");
        while (r->left != nullptr)
            r = r->left;
        return r;
    }

    /// Returns the node with the largest key in the subtree rooted at r.
    /// Throws std::out_of_range if the subtree is empty.
    Node* max(Node* r) const
    {
        if (r == nullptr)
            throw std::out_of_range("can't gain the max");
        while (r->right != nullptr)
            r = r->right;
        return r;
    }

    /// Unlinks, WITHOUT freeing, the minimum node of the non-empty subtree r.
    /// The unlinked node is returned through `detached` (children cleared,
    /// count reset to 1); the function returns the new subtree root.
    Node* detach_min(Node* r, Node*& detached)
    {
        std::stack<Node*> path;  // ancestors of the minimum
        Node* parent = nullptr;
        Node* curr = r;
        while (curr->left != nullptr)
        {
            path.push(curr);
            parent = curr;
            curr = curr->left;
        }
        // The minimum has no left child; its right subtree takes its place.
        if (parent == nullptr)
            r = curr->right;
        else
            parent->left = curr->right;
        curr->right = nullptr;
        curr->count = 1;
        detached = curr;
        fix_counts(path);
        return r;
    }

    /// Deletes the node with the smallest key; returns the new subtree root.
    /// An empty subtree is left unchanged.
    Node* deleteMin(Node* r)
    {
        if (r == nullptr)
            return nullptr;
        Node* victim = nullptr;
        r = detach_min(r, victim);
        delete victim;
        return r;
    }

    /// Deletes the node with the largest key; returns the new subtree root.
    /// An empty subtree is left unchanged.
    Node* deleteMax(Node* r)
    {
        if (r == nullptr)
            return nullptr;
        std::stack<Node*> path;  // ancestors of the maximum
        Node* parent = nullptr;
        Node* curr = r;
        while (curr->right != nullptr)
        {
            path.push(curr);
            parent = curr;
            curr = curr->right;
        }
        // The maximum has no right child; its left subtree takes its place.
        if (parent == nullptr)
            r = curr->left;
        else
            parent->right = curr->left;
        delete curr;
        fix_counts(path);
        return r;
    }

    /// In-order traversal: prints every value to std::cout in ascending key
    /// order, each followed by a space. (The name is historical; this is an
    /// in-order, not a pre-order, walk.)
    void before_display(Node* r) const
    {
        std::stack<Node*> pending;
        while (r != nullptr || !pending.empty())
        {
            if (r != nullptr)
            {
                pending.push(r);
                r = r->left;
            }
            else
            {
                r = pending.top();
                pending.pop();
                std::cout << r->value << ' ';
                r = r->right;
            }
        }
    }

    /// Removes the node with key k from the subtree rooted at r, if present,
    /// and returns the new subtree root. Absent keys leave the tree unchanged.
    Node* erase(Node* r, const Key& k)
    {
        std::stack<Node*> path;  // ancestors of the node being removed
        Node* parent = nullptr;
        Node* curr = r;
        while (curr != nullptr)
        {
            if (k < curr->key)
            {
                path.push(curr);
                parent = curr;
                curr = curr->left;
            }
            else if (curr->key < k)
            {
                path.push(curr);
                parent = curr;
                curr = curr->right;
            }
            else
                break;
        }
        if (curr == nullptr)
            return r;  // key not present

        Node* replacement = nullptr;
        if (curr->left == nullptr)
            replacement = curr->right;
        else if (curr->right == nullptr)
            replacement = curr->left;
        else
        {
            // Two children: splice in the in-order successor (the minimum of
            // the right subtree). It is unlinked, not freed, so it can be
            // reused in curr's position.
            Node* successor = nullptr;
            Node* rest = detach_min(curr->right, successor);
            successor->right = rest;
            successor->left = curr->left;
            successor->count = size(successor->left) + size(successor->right) + 1;
            replacement = successor;
        }

        if (parent == nullptr)
            r = replacement;
        else if (parent->left == curr)
            parent->left = replacement;
        else
            parent->right = replacement;
        delete curr;
        fix_counts(path);
        return r;
    }

    /// Returns the node with the largest key <= k, or nullptr if every key
    /// in the subtree is greater than k (or the subtree is empty).
    Node* floor(Node* r, const Key& k) const
    {
        Node* best = nullptr;  // last node seen whose key is < k
        while (r != nullptr)
        {
            if (k < r->key)
                r = r->left;
            else if (r->key < k)
            {
                best = r;
                r = r->right;
            }
            else
                return r;
        }
        return best;
    }

public:
    BST() = default;

    /// Deep copy.
    BST(const BST& other) : root(clone(other.root))
    {
    }

    /// Takes ownership of other's nodes; other is left empty.
    BST(BST&& other) noexcept : root(other.root)
    {
        other.root = nullptr;
    }

    /// Copy-and-swap: handles both copy and move assignment.
    BST& operator=(BST other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    ~BST()
    {
        destroy(root);
    }

    /// Number of stored keys.
    int size() const
    {
        return size(root);
    }

    /// Stores v under k, overwriting any existing value for k.
    void put(const Key& k, const Value& v)
    {
        root = put(root, k, v);
    }

    /// Value stored under k; throws std::out_of_range if absent.
    Value get(const Key& k) const
    {
        return get(root, k);
    }

    /// Smallest key; throws std::out_of_range if the tree is empty.
    Key min() const
    {
        return min(root)->key;
    }

    /// Largest key; throws std::out_of_range if the tree is empty.
    Key max() const
    {
        return max(root)->key;
    }

    /// Removes the smallest key (no-op on an empty tree).
    void deleteMin()
    {
        root = deleteMin(root);
    }

    /// Removes the largest key (no-op on an empty tree).
    void deleteMax()
    {
        root = deleteMax(root);
    }

    /// Prints all values in ascending key order, space-separated.
    void before_display() const
    {
        before_display(root);
    }

    /// Removes k if present; absent keys are ignored.
    void erase(const Key& k)
    {
        root = erase(root, k);
    }

    /// Largest key <= k; throws std::out_of_range if no such key exists.
    Key floor(const Key& k) const
    {
        Node* ret = floor(root, k);
        if (ret == nullptr)
            throw std::out_of_range("can't floor");
        return ret->key;
    }
};

main.cpp

#include <iostream>#include "BST.h"using namespace std;int main(){    BST<double, int> bst;    for (int i = 0; i < 10; ++i)        bst.put(i + 0.1, i);    cout << "size: " << bst.size() << endl;    try    {        cout << bst.get(1.1) << endl;        bst.deleteMin();        bst.deleteMax();        bst.erase(3.1);        cout << "min: " << bst.min() << endl;        cout << "max: " << bst.max() << endl;        cout << bst.size() << endl;        cout << bst.floor(5.0) << endl;    }    catch (const exception& e)    {        cout << e.what();    }    bst.before_display();    system("pause");    return 0;}

运行:

（此处原为运行结果截图，图片已丢失）

原创粉丝点击