CP::map_avl

The map class, implemented with an AVL tree:
[code lang="c++"]
#ifndef CP_MAP_AVL_INCLUDED
#define CP_MAP_AVL_INCLUDED

#include <iostream>
#include <functional>
#include <utility>
#include <assert.h>
//#pragma once

namespace CP {

// Ordered associative container (a subset of the std::map interface) backed by
// an AVL tree with parent pointers, so iterators can walk the tree in order.
template <typename KeyT, typename MappedT, typename CompareT = std::less<KeyT> >
class map_avl {
  protected:

    // Element type: a (key, mapped value) pair. Unlike std::map the key is not
    // const, because erase() swaps keys between nodes.
    typedef std::pair<KeyT,MappedT> ValueT;

    //-------------------------------------------------------------
    // One tree node. `height` caches the height of the subtree rooted here
    // (a single node has height 0, an empty subtree counts as -1).
    class node {
      friend class map_avl;
      protected:
        ValueT data;
        node  *left;
        node  *right;
        node  *parent;
        int    height;

        node() :
          data( ValueT() ), left( NULL ), right( NULL ), parent( NULL ), height(0) { }

        // Builds a node with the given links and computes its height from the
        // children's cached heights. The children's parent pointers are not touched.
        node(const ValueT& data, node* left, node* right, node* parent) :
          data ( data ), left( left ), right( right ), parent( parent ), height(0) {
            set_height();
        }

        // Height of a possibly-NULL subtree (-1 for NULL). Static because it is
        // applied to child pointers that may be NULL (the original author's note,
        // in mis-encoded Thai, read "not OO?").
        static int get_height(node *n) {
            return (n == NULL ? -1 : n->height);
        }

        // Recomputes this node's height from its children's cached heights.
        void set_height() {
            int hL = get_height(this->left);
            int hR = get_height(this->right);
            height = 1 + (hL > hR ? hL : hR);
        }

        // AVL balance factor: height(right) - height(left).
        int balance_value() {
            return get_height(this->right) - get_height(this->left);
        }

        // Links n as the left child and, if n is non-NULL, points it back here.
        void set_left(node *n) {
            this->left = n;
            if (n != NULL) n->parent = this;
        }

        // Links n as the right child and, if n is non-NULL, points it back here.
        void set_right(node *n) {
            this->right = n;
            if (n != NULL) n->parent = this;
        }
    };

    //-------------------------------------------------------------
    // Bidirectional in-order iterator. end() is represented by a NULL node
    // pointer: incrementing past the last element yields end(), while
    // decrementing end() is not supported (it would dereference NULL).
    class tree_iterator {
      protected:
        node* ptr;

      public:
        tree_iterator() : ptr( NULL ) { }

        tree_iterator(node *a) : ptr(a) { }

        // Advance to the in-order successor.
        tree_iterator& operator++() {
          if (ptr->right == NULL) {
            // No right subtree: climb while we are a right child. The first
            // ancestor reached from its left side is the successor (NULL = end).
            node *parent = ptr->parent;
            while (parent != NULL && parent->right == ptr) {
              ptr = parent;
              parent = ptr->parent;
            }
            ptr = parent;
          } else {
            // The successor is the leftmost node of the right subtree.
            ptr = ptr->right;
            while (ptr->left != NULL)
              ptr = ptr->left;
          }
          return (*this);
        }

        // Step back to the in-order predecessor (mirror image of operator++).
        tree_iterator& operator--() {
          if (ptr->left == NULL) {
            node *parent = ptr->parent;
            while (parent != NULL && parent->left == ptr) {
              ptr = parent;
              parent = ptr->parent;
            }
            ptr = parent;
          } else {
            ptr = ptr->left;
            while (ptr->right != NULL)
              ptr = ptr->right;
          }
          return (*this);
        }

        tree_iterator operator++(int) {
          tree_iterator tmp(*this);
          operator++();
          return tmp;
        }

        tree_iterator operator--(int) {
          tree_iterator tmp(*this);
          operator--();
          return tmp;
        }

        ValueT& operator*()  { return ptr->data;    }
        ValueT* operator->() { return &(ptr->data); }
        bool    operator==(const tree_iterator& other) const { return other.ptr == ptr; }
        bool    operator!=(const tree_iterator& other) const { return other.ptr != ptr; }
    };

    //-------------------------------------------------------------
    node    *mRoot;   // root of the tree (NULL when empty); its parent is kept NULL
    CompareT mLess;   // key ordering used for all comparisons
    size_t   mSize;   // number of stored elements

  public:
    typedef tree_iterator iterator;

  protected:
    // Three-way comparison built from mLess:
    // -1 if k1 < k2, +1 if k2 < k1, 0 if the keys are equivalent.
    int compare(const KeyT& k1, const KeyT& k2) {
      if (mLess(k1, k2)) return -1;
      if (mLess(k2, k1)) return +1;
      return 0;
    }

    // Searches the subtree r for key and returns the matching node or NULL.
    // On return `parent` holds the parent of the returned node (unchanged if r
    // itself matches), or the last node visited when the key is absent.
    node *find_node(const KeyT &key, node *r, node *&parent) {
      if (r == NULL) return NULL;
      int cmp = compare(key, r->data.first);
      if (cmp == 0) return r;
      parent = r;
      return find_node(key, cmp < 0 ? r->left : r->right, parent);
    }

    // Leftmost (smallest-key) node of subtree r; r must not be NULL.
    node *find_min_node(node *r) {
      node *min = r;
      while (min->left != NULL) {
        min = min->left;
      }
      return min;
    }

    // Rightmost (largest-key) node of subtree r; r must not be NULL.
    node *find_max_node(node *r) {
      node *max = r;
      while (max->right != NULL) {
        max = max->right;
      }
      return max;
    }

    // Deep-copies subtree src, attaching the copy below `parent`.
    node *copy(node *src, node *parent) {
      if (src == NULL) return NULL;
      node *tmp = new node();
      tmp->data = src->data;
      // The cached height must be copied too; otherwise every copied node keeps
      // height 0 and later rebalancing (and height()) works on wrong heights.
      tmp->height = src->height;
      tmp->parent = parent;
      tmp->left = copy(src->left, tmp);
      tmp->right = copy(src->right, tmp);
      return tmp;
    }

    // Frees every node of subtree r (post-order).
    void delete_all_nodes(node *r) {
      if (r == NULL) return;
      delete_all_nodes(r->left);
      delete_all_nodes(r->right);
      delete r;
    }

    // Single rotation lifting r's left child into r's place (fixes a left-heavy r).
    // Returns the new subtree root. Its parent pointer is NOT updated here;
    // the caller relinks it via set_left/set_right or the root fix-up.
    node *rotate_left_child(node *r) {
      node *new_root = r->left;
      r->set_left(new_root->right);
      new_root->set_right(r);
      new_root->right->set_height();
      new_root->set_height();
      return new_root;
    }

    // Single rotation lifting r's right child into r's place (fixes a right-heavy r).
    // Same parent-pointer caveat as rotate_left_child.
    node *rotate_right_child(node *r) {
      node *new_root = r->right;
      r->set_right(new_root->left);
      new_root->set_left(r);
      new_root->left->set_height();
      new_root->set_height();
      return new_root;
    }

    // Restores the AVL property at r for balance factors of -2/+2 (single or
    // double rotation), refreshes r's height and returns the new subtree root.
    node *rebalance(node *r) {
      if (r == NULL) return r;
      int balance = r->balance_value();
      if (balance == -2) {
        // Left-heavy; a right-leaning left child needs a double rotation.
        if (r->left->balance_value() == 1) {
          r->set_left(rotate_right_child(r->left));
        }
        r = rotate_left_child(r);
      } else if (balance == 2) {
        // Right-heavy; a left-leaning right child needs a double rotation.
        if (r->right->balance_value() == -1) {
          r->set_right(rotate_left_child(r->right));
        }
        r = rotate_right_child(r);
      }
      r->set_height();
      return r;
    }

    // Inserts val into subtree r unless its key already exists. `ptr` receives
    // the node holding the key (new or existing); an existing value is kept.
    // Returns the rebalanced subtree root (its parent pointer may be stale).
    node *insert(const ValueT& val, node *r, node *&ptr) {
      if (r == NULL) {
        mSize++;
        ptr = r = new node(val, NULL, NULL, NULL);
      } else {
        int cmp = compare(val.first, r->data.first);
        if (cmp == 0) {
          ptr = r;
        } else if (cmp < 0) {
          r->set_left(insert(val, r->left, ptr));
        } else {
          r->set_right(insert(val, r->right, ptr));
        }
      }
      return rebalance(r);
    }

    // Removes key from subtree r and returns the rebalanced subtree root, whose
    // parent pointer may be stale (callers relink it).
    // When the matching node has two children, its in-order successor's pair is
    // swapped into it and the successor's node is deleted instead. Iterators to
    // the successor element are therefore invalidated as well.
    node *erase(const KeyT &key, node *r) {
      if (r == NULL) return NULL;
      int cmp = compare(key, r->data.first);
      if (cmp < 0) {
        r->set_left(erase(key, r->left));
      } else if (cmp > 0) {
        r->set_right(erase(key, r->right));
      } else {
        if (r->left == NULL || r->right == NULL) {
          // Zero or one child: splice the child (possibly NULL) into r's place.
          node *n = r;
          r = (r->left == NULL ? r->right : r->left);
          delete n;
          mSize--;
        } else {
          node *m = find_min_node(r->right);
          std::swap(r->data.first, m->data.first);
          std::swap(r->data.second, m->data.second);
          // m now holds `key` and is still the leftmost node of the right
          // subtree, so the recursive search reaches and removes it.
          r->set_right(erase(m->data.first, r->right));
        }
      }
      return rebalance(r);
    }

  public:
    //-------------- constructor & copy operator ----------

    // copy constructor: deep-copies other's tree
    map_avl(const map_avl<KeyT,MappedT,CompareT> & other) :
      mRoot(NULL), mLess(other.mLess), mSize(other.mSize)
    {
      mRoot = copy(other.mRoot, NULL);
    }

    // default constructor, optionally taking a comparator instance
    map_avl(const CompareT& c = CompareT() ) :
      mRoot(NULL), mLess(c), mSize(0)
    { }

    // copy assignment operator using copy-and-swap idiom:
    // `other` is a by-value copy; swapping hands our old tree to it, and it is
    // destroyed at the end of this scope.
    map_avl<KeyT,MappedT,CompareT>& operator=(map_avl<KeyT,MappedT,CompareT> other) {
      using std::swap;
      swap(this->mRoot, other.mRoot);
      swap(this->mLess, other.mLess);
      swap(this->mSize, other.mSize);
      return *this;
    }

    ~map_avl() {
      clear();
    }

    bool empty() const {
      return mSize == 0;
    }

    size_t size() const {
      return mSize;
    }

    // Iterator to the smallest key, or end() when empty.
    iterator begin() {
      return iterator(mRoot == NULL ? NULL : find_min_node(mRoot));
    }

    iterator end() {
      return iterator(NULL);
    }

    // Iterator to the element with `key`, or end() if absent.
    iterator find(const KeyT &key) {
      node *parent = NULL;
      node *ptr = find_node(key, mRoot, parent);
      return ptr == NULL ? end() : iterator(ptr);
    }

    void clear() {
      delete_all_nodes(mRoot);
      mRoot = NULL;
      mSize = 0;
    }

    // Returns the mapped value for key, inserting a value-initialised one first if absent.
    MappedT& operator[](const KeyT& key) {
      std::pair<iterator,bool> result = insert(std::make_pair(key, MappedT()));
      return result.first->second;
    }

    // Inserts val if its key is absent. Returns (iterator to the element with
    // that key, true if an insertion happened).
    std::pair<iterator,bool> insert(const ValueT& val) {
      node *ptr = NULL;
      size_t s = mSize;
      mRoot = insert(val, mRoot, ptr);
      mRoot->parent = NULL;
      return std::make_pair(iterator(ptr), (mSize > s));
    }

    // Removes the element with key; returns the number removed (0 or 1).
    size_t erase(const KeyT &key) {
      size_t s = mSize;
      mRoot = erase(key, mRoot);
      // The recursive erase may have replaced or rotated the root, leaving its
      // parent pointer dangling or pointing at its own child; iterators climb
      // parent pointers, so the root must be re-anchored.
      if (mRoot != NULL) mRoot->parent = NULL;
      return s == mSize ? 0 : 1;
    }

    //----------------------------------------------------------------
    // Debug dump: right subtree first, one line per node, indented by depth;
    // reports any child whose parent pointer does not point back.
    void print_node(node *n, size_t depth) {
      if (n != NULL) {
        if (n->right != NULL && n->right->parent != n)
          std::cout << "parent of " << n->right->data.first << std::endl;
        if (n->left != NULL && n->left->parent != n)
          std::cout << "parent of " << n->left->data.first << std::endl;
        print_node(n->right, depth + 1);
        for (size_t i = 0; i < depth; i++) std::cout << "--";
        std::cout << " " << n->data.first << ":" << n->data.second << std::endl;
        print_node(n->left, depth + 1);
      }
    }

    void print() {
      std::cout << " ======== size = " << mSize << " ========= " << std::endl;
      print_node(mRoot, 0);
    }

    // True if the root has no parent and every child points back to its parent.
    bool checkParent() {
      if (mRoot != NULL && mRoot->parent != NULL) return false;
      return checkParent(mRoot);
    }

    bool checkParent(node *r) {
      if (r == NULL) return true;
      if (r->left != NULL && r != r->left->parent) return false;
      if (r->right != NULL && r != r->right->parent) return false;
      if (!checkParent(r->left)) return false;
      return checkParent(r->right);
    }

    // True if every node is strictly ordered with respect to its direct children.
    bool checkInorder() {
      return checkInorder(mRoot);
    }

    bool checkInorder(node *r) {
      if (r == NULL) return true;
      if (r->left != NULL && !mLess(r->left->data.first, r->data.first)) return false;
      if (r->right != NULL && !mLess(r->data.first, r->right->data.first)) return false;
      if (!checkInorder(r->left)) return false;
      return checkInorder(r->right);
    }

    // Recomputes the height of subtree r from scratch (-1 for NULL).
    int height(node *r) {
      if (r == NULL) return -1;
      int hl = height(r->left);
      int hr = height(r->right);
      return 1 + (hl > hr ? hl : hr);
    }

    // Tree height (-1 when empty); asserts that the cached root height is correct.
    int height() {
      if (mRoot == NULL) return -1;
      assert(height(mRoot) == mRoot->height);
      return mRoot->height;
    }

};

}

#endif

[/code]

The testing code [code lang="c++"] #include #include #include #include #include "map_avl.h" #include <assert.h> #include <stdlib.h> #include #include

bool test1() { CP::map_avl<int,std::string> m; m[10] = "A"; m[8] = "B"; m[6] = "C"; m[11] = "D"; m[12] = "E";

m.print(); m.insert(std::make_pair(13,"x")); m.insert(std::make_pair(7,"x")); m[5] = "E"; m.print(); m.erase(10); m.print(); m.erase(8); m.print(); m.erase(5); m.print();

auto result = m.insert(std::make_pair(12,"x")); assert(result.second == false); while (result.first != m.end()) { std::cout << result.first->first << ": " << result.first->second << std::endl; result.first++; } return true; }

bool test2() { CP::map_avl<int,bool> m; m[1] = m[2] = m[3] = m[4] = m[20] = m [11] = m[5] = m[9] = m[7] = false; m[-4] = m[-2] = m[-1] = m[-3] = true; for (auto& x : m) { std::cout << x.first << ": " << x.second << std::endl; }

return true; }

// Simple value type used as a map key in the tests. operator< looks at the
// name only, so two objects with the same name are equivalent keys whatever
// their value.
class TestClass {
  public:
    std::string name;
    int value;

    TestClass() : name(), value(0) { }

    TestClass(std::string n, int v) : name(n), value(v) { }

    TestClass(const TestClass &other) : name(other.name), value(other.value) { }

    // Logs the comparison to std::cout, then requires both fields to match.
    bool operator==(const TestClass& other) const {
      std::cout << "comparing " << (*this) << " and " << other << std::endl;
      const bool sameName = (name == other.name);
      return sameName && value == other.value;
    }

    // Orders by name alone.
    bool operator<(const TestClass& other) const {
      return other.name > name;
    }

    // Prints as "(name:<name>, value:<value>)".
    friend std::ostream& operator<<(std::ostream& os,const TestClass& c) {
      return os << "(name:" << c.name << ", value:" << c.value << ")";
    }

};

// Function comparator ordering TestClass objects by name, descending
// (a comes first when its name is greater).
bool comparator(const TestClass &a,const TestClass &b) {
  return b.name < a.name;
}

// Function-object comparator ordering TestClass objects by ascending `value`
// (names are ignored). operator() is const so the comparator can also be
// invoked through a const object or reference, as standard associative
// containers expect of their Compare type.
class ComparatorClass {
  public:
    bool operator()(const TestClass &a,const TestClass &b) const {
      return a.value < b.value;
    }
};

// Function-pointer comparator type named "CompFunc" (a plain function pointer,
// not a functor). A function with this signature, such as `comparator`, can be
// passed as the CompareT instance of CP::map_avl.
using CompFunc = bool (*)(const TestClass&, const TestClass&);

// Uses TestClass keys with the default comparator (which compares names only),
// then copy-assigns the map and prints the copy.
bool test3() {
  CP::map_avl<TestClass,float> m;

  m[TestClass("somchai",1)] = 0.1;
  m[TestClass("nuttapong",1)] = 1.1;
  m[TestClass("nattee",1)] = 2.2;
  m[TestClass("vishnu",1)] = 3.1;
  m.print();

  // Same name => same key: both lines update the existing "somchai" entry.
  m[TestClass("somchai",1)] = 0.2;
  m[TestClass("somchai",2)] = 99;
  m.print();

  m[TestClass("xxx",1)] = 11;
  m[TestClass("x1",2)] = 23;

  CP::map_avl<TestClass,float> m2;
  m2 = m;
  m2.print();
  return true;
}

bool test4() { std::vector<std::pair<TestClass,float> > data;

data.push_back(std::make_pair(TestClass("somchai",1),99)); data.push_back(std::make_pair(TestClass("nuttapong",1),1.1)); data.push_back(std::make_pair(TestClass("nattee",1),2.2)); data.push_back(std::make_pair(TestClass("vishnu",1),3.1)); data.push_back(std::make_pair(TestClass("xxx",1),11)); data.push_back(std::make_pair(TestClass("x1",2),23));

CP::map_avl<TestClass,float,ComparatorClass> m; for (auto& x : data) { m[x.first] = x.second; } m.print();

CP::map_avl<TestClass,float,CompFunc> m2(comparator); for (auto& x : data) { m2[x.first] = x.second; } m2.print();

return true; } bool test5() { CP::map_avl<int, int> m1; m1[4] = 1; m1[3] = 2; m1[6] = 9; CP::map_avl<int, int> m2(m1); m1[10] = 13;

CP::map_avl<int, int>::iterator it = m1.begin(); if ((it->first != 3) || (it->second != 2)) return false; it++; if ((it->first != 4) || (it->second != 1)) return false; it++; if ((it->first != 6) || (it->second != 9)) return false; it++; if ((it->first != 10) || (it->second != 13)) return false; it++; if (it != m1.end()) return false; it = m2.begin(); if ((it->first != 3) || (it->second != 2)) return false; it++; if ((it->first != 4) || (it->second != 1)) return false; it++; if ((it->first != 6) || (it->second != 9)) return false; it++; if (it != m2.end()) return false; m1.print(); m2.print(); return true; } //============================================== bool test_tree_sort() { int n = 10000; float *d1 = new float[n]; float *d2 = new float[n]; for (int i=0; i<n; i++) { d1[i] = rand(); d2[i] = d1[i]; } CP::map_avl<float,int> m; for (int i=0; i<n; i++) m[d1[i]]++; int k = 0; for (auto& v : m) { for (int i=0; i<v.second; i++) { d1[k++] = v.first; } } std::sort(d2, d2+n, [](const float& a, const float& b) -> bool {return a < b;} ); for (int i=0; i<n; i++) { assert(d1[i] == d2[i]); } return true; } //============================================== bool test_big_tree() { std::cout << "Testing big tree"; CP::map_avl<int, int> m; assert(m.erase(99) == 0); size_t n = 10000000; int *d = new int[n]; for (size_t i=0; i<n; i++) d[i] = i; for (size_t i=0; i<n; i++) { size_t j = rand()%n; int t = d[i]; d[i] = d[j]; d[j] = t; } for (size_t i=0; i<n/4; i++) { m[d[i]] = 1; assert(m[d[i]] == 1); if (i % (n/10) == 0) std::cout << "."; } for (size_t i=n/4; i<n/2; i++) { assert(m[d[i]] == 0); m[d[i]] = 1; assert(m[d[i]] == 1); if (i % (n/10) == 0) std::cout << "."; } for (size_t i=n/2; i<n; i++) { std::pair<CP::map_avl<int, int>::iterator,bool> p = m.insert(std::make_pair(d[i],1)); assert(p.second == true); assert((p.first)->first == d[i]);

p = m.insert(std::make_pair(d[i],1));
assert(p.second == false);
assert((p.first)->first == d[i]);
if (i % (n/10) == 0) std::cout << ".";

} assert(m.size() == n); for (size_t i=0; i<n/5; i+=2) assert(m.erase(d[i])==1); assert(m.size() == n - n/10); std::cout << "\nsize = " << m.size() << ", height = " << m.height() << std::endl; assert(m.checkParent()); assert(m.checkInorder()); int last = -999999999; for (auto & x : m) { assert(last < x.first); last = x.first; } return true; } //================================================= int main() { if (test1()) std::cout << "-------------------- TEST 1 OK -------------------" << std::endl; if (test2()) std::cout << "-------------------- TEST 2 OK -------------------" << std::endl; if (test3()) std::cout << "-------------------- TEST 3 OK -------------------" << std::endl; if (test4()) std::cout << "-------------------- TEST 4 OK -------------------" << std::endl; if (test5()) std::cout << "-------------------- TEST 5 OK -------------------" << std::endl; if (test_tree_sort()) std::cout << "-------------------- TEST TreeSort OK -------------------" << std::endl; if (test_big_tree()) std::cout << "-------------------- TEST BigTree OK -------------------" << std::endl; }

[/code]