CP::map_bst

The map class, implemented with a binary search tree:
[code lang="c++"]
#ifndef CP_MAP_BST_INCLUDED
#define CP_MAP_BST_INCLUDED

#include <cstddef>
#include <functional>
#include <iostream>
#include <utility>
//#pragma once

namespace CP {

/**
 * An ordered key -> value container (a teaching version of std::map) backed
 * by a plain binary search tree. The tree is never rebalanced, so every
 * operation costs O(height), and inserting keys in sorted order produces a
 * list-shaped tree whose height equals size().
 *
 * @tparam KeyT     key type, ordered by CompareT.
 * @tparam MappedT  mapped type; operator[] value-initialises it (MappedT()).
 * @tparam CompareT strict ordering predicate on keys: a function object or a
 *                  function pointer handed to the constructor.
 */
template <typename KeyT, typename MappedT, typename CompareT = std::less<KeyT> >
class map_bst {
  protected:

    typedef std::pair<KeyT,MappedT> ValueT;

    // One tree node: a key/value pair plus links to both children and to the
    // parent. The parent link is what lets tree_iterator step to the in-order
    // neighbour without keeping an explicit stack.
    class node {
      friend class map_bst;
      protected:
        ValueT data;
        node  *left;
        node  *right;
        node  *parent;

        node() :
          data( ValueT() ), left( nullptr ), right( nullptr ), parent( nullptr ) { }

        node(const ValueT& data, node* left, node* right, node* parent) :
          data ( data ), left( left ), right( right ), parent( parent ) { }
    };

    //-------------------------------------------------------------
    // Bidirectional in-order iterator. end() is represented by a null node
    // pointer: stepping past the largest key (or before the smallest) yields
    // end(), and end() itself must not be incremented or decremented (both
    // operators dereference the node pointer).
    class tree_iterator {
      protected:
        node* ptr;

      public:
        tree_iterator() : ptr( nullptr ) { }

        tree_iterator(node *a) : ptr(a) { }

        // Move to the in-order successor: the leftmost node of the right
        // subtree if there is one, otherwise the first ancestor that is
        // reached from its left side (nullptr when there is none).
        tree_iterator& operator++() {
          if (ptr->right == nullptr) {
            node *parent = ptr->parent;
            while (parent != nullptr && parent->right == ptr) {
              ptr = parent;
              parent = ptr->parent;
            }
            ptr = parent;
          } else {
            ptr = ptr->right;
            while (ptr->left != nullptr)
              ptr = ptr->left;
          }
          return (*this);
        }

        // Mirror image of operator++: move to the in-order predecessor.
        tree_iterator& operator--() {
          if (ptr->left == nullptr) {
            node *parent = ptr->parent;
            while (parent != nullptr && parent->left == ptr) {
              ptr = parent;
              parent = ptr->parent;
            }
            ptr = parent;
          } else {
            ptr = ptr->left;
            while (ptr->right != nullptr)
              ptr = ptr->right;
          }
          return (*this);
        }

        tree_iterator operator++(int) {
          tree_iterator tmp(*this);
          operator++();
          return tmp;
        }

        tree_iterator operator--(int) {
          tree_iterator tmp(*this);
          operator--();
          return tmp;
        }

        ValueT& operator*()  { return ptr->data;    }
        ValueT* operator->() { return &(ptr->data); }
        bool    operator==(const tree_iterator& other) const { return other.ptr == ptr; }
        bool    operator!=(const tree_iterator& other) const { return other.ptr != ptr; }
    };

    //-------------------------------------------------------------
    node        *mRoot;  // root of the tree; nullptr when the map is empty
    CompareT     mLess;  // key ordering predicate
    std::size_t  mSize;  // number of stored entries

  public:
    typedef tree_iterator iterator;

  protected:
    // Returns a reference to the pointer that links a child with key k
    // under `parent`: mRoot when parent is null, otherwise parent->left when
    // k orders before parent's key and parent->right otherwise.
    node* &child_link(node* parent, const KeyT& k) {
      if (parent == nullptr) return mRoot;
      return mLess(k, parent->data.first) ? parent->left : parent->right;
    }

    // Three-way comparison built from mLess: -1 if k1 < k2, +1 if k2 < k1,
    // 0 if the keys are equivalent.
    int compare(const KeyT& k1, const KeyT& k2) {
      if (mLess(k1, k2)) return -1;
      if (mLess(k2, k1)) return +1;
      return 0;
    }

    // Searches the subtree rooted at r for key k. Returns the matching node
    // or nullptr. Every node the search descends past is written to `parent`,
    // so on a miss `parent` is the node a new key-k node should hang from;
    // it is left untouched when r is null or r itself matches, so callers
    // initialise it to nullptr first.
    node* find_node(const KeyT& k, node* r, node* &parent) {
      node *ptr = r;
      while (ptr != nullptr) {
        int cmp = compare(k, ptr->data.first);
        if (cmp == 0) return ptr;
        parent = ptr;
        ptr = cmp < 0 ? ptr->left : ptr->right;
      }
      return nullptr;
    }

    // Leftmost (smallest-key) node of the subtree rooted at r.
    // r must not be nullptr.
    node* find_min_node(node* r) {
      node *min = r;
      while (min->left != nullptr) {
        min = min->left;
      }
      return min;
    }

    // Rightmost (largest-key) node of the subtree rooted at r.
    // r must not be nullptr.
    node* find_max_node(node* r) {
      node *max = r;
      while (max->right != nullptr) {
        max = max->right;
      }
      return max;
    }

    // Deep-copies the subtree rooted at src; the copy's root gets `parent`
    // as its parent link. The source is walked with its parent links rather
    // than by recursion, so a list-shaped tree cannot exhaust the call stack.
    // If copying a value throws, the partial copy is freed and the exception
    // is rethrown.
    node* copy(node* src, node* parent) {
      if (src == nullptr) return nullptr;
      node *root = new node(src->data, nullptr, nullptr, parent);
      try {
        node *s = src;   // current node in the source tree
        node *d = root;  // its counterpart in the copy
        for (;;) {
          if (s->left != nullptr && d->left == nullptr) {
            d->left = new node(s->left->data, nullptr, nullptr, d);
            s = s->left;
            d = d->left;
          } else if (s->right != nullptr && d->right == nullptr) {
            d->right = new node(s->right->data, nullptr, nullptr, d);
            s = s->right;
            d = d->right;
          } else if (s == src) {
            break;               // both subtrees of the root are done
          } else {
            s = s->parent;       // this subtree is done; climb back up
            d = d->parent;
          }
        }
      } catch (...) {
        delete_all_nodes(root);  // every node created so far hangs off root
        throw;
      }
      return root;
    }

    // Frees every node of the subtree rooted at r. Iterative: a node with a
    // left child is first rotated right, so the loop only ever deletes nodes
    // without a left child and needs no recursion (parent links are ignored
    // because the whole subtree is being destroyed).
    void delete_all_nodes(node *r) {
      while (r != nullptr) {
        if (r->left != nullptr) {
          node *l  = r->left;
          r->left  = l->right;
          l->right = r;
          r = l;
        } else {
          node *next = r->right;
          delete r;
          r = next;
        }
      }
    }

    // True when every key in the subtree rooted at r lies strictly inside
    // the open interval (*lo, *hi) under mLess (a null bound means
    // unbounded) and the same holds recursively for each child with the
    // bound tightened to the current key.
    bool check_inorder(node *r, const KeyT* lo, const KeyT* hi) {
      if (r == nullptr) return true;
      const KeyT& k = r->data.first;
      if (lo != nullptr && !mLess(*lo, k)) return false;
      if (hi != nullptr && !mLess(k, *hi)) return false;
      return check_inorder(r->left, lo, &k) && check_inorder(r->right, &k, hi);
    }

  public: //-------------- constructor & copy operator ----------

    // copy constructor: deep copy of other's tree and comparator
    map_bst(const map_bst<KeyT,MappedT,CompareT> & other) :
      mRoot(nullptr), mLess(other.mLess), mSize(other.mSize)
    {
      mRoot = copy(other.mRoot, nullptr);
    }

    // move constructor: steals other's tree and leaves other empty
    map_bst(map_bst<KeyT,MappedT,CompareT> && other) :
      mRoot(other.mRoot), mLess(other.mLess), mSize(other.mSize)
    {
      other.mRoot = nullptr;
      other.mSize = 0;
    }

    // default constructor; c is the key ordering (e.g. a function pointer)
    map_bst(const CompareT& c = CompareT() ) :
      mRoot(nullptr), mLess(c) , mSize(0)
    { }

    // copy/move assignment using the copy-and-swap idiom: `other` is copy-
    // or move-constructed, takes over our old contents in the swap, and
    // releases them when it is destroyed at the end of this scope.
    map_bst<KeyT,MappedT,CompareT>& operator=(map_bst<KeyT,MappedT,CompareT> other)  {
      using std::swap;
      swap(this->mRoot, other.mRoot);
      swap(this->mLess, other.mLess);
      swap(this->mSize, other.mSize);
      return *this;
    }

    ~map_bst() {
      clear();
    }

    //------------- capacity function -------------------
    bool empty() const {
      return mSize == 0;
    }

    std::size_t size() const {
      return mSize;
    }

    //----------------- iterator ---------------
    // Iterator to the smallest key, or end() for an empty map.
    iterator begin() {
      return iterator(mRoot == nullptr ? nullptr : find_min_node(mRoot));
    }

    iterator end() {
      return iterator(nullptr);
    }

    //----------------- modifier -------------
    // Removes every entry.
    void clear() {
      delete_all_nodes(mRoot);
      mRoot = nullptr;
      mSize = 0;
    }

    // Returns a reference to the value mapped to key, first inserting
    // (key, MappedT()) when the key is absent.
    MappedT& operator[](const KeyT& key) {
      node *parent = nullptr;
      node *ptr = find_node(key, mRoot, parent);
      if (ptr == nullptr) {
        ptr = new node(std::make_pair(key,MappedT()),nullptr,nullptr,parent);
        child_link(parent, key) = ptr;
        mSize++;
      }
      return ptr->data.second;
    }

    // Inserts val unless its key is already present. Returns an iterator to
    // the entry with that key and true if val was inserted, false if an
    // existing entry (left unchanged) was found.
    std::pair<iterator,bool> insert(const ValueT& val) {
      node *parent = nullptr;
      node *ptr = find_node(val.first,mRoot,parent);
      bool not_found = (ptr==nullptr);
      if (not_found) {
        ptr = new node(val,nullptr,nullptr,parent);
        child_link(parent, val.first) = ptr;
        mSize++;
      }
      return std::make_pair(iterator(ptr), not_found);
    }

    // Iterator to the entry with key, or end() if there is none.
    iterator find(const KeyT &key) {
      node *parent = nullptr;
      node *ptr = find_node(key,mRoot,parent);
      return ptr == nullptr ? end() : iterator(ptr);
    }

    // Removes the entry with key; returns the number removed (0 or 1).
    // A node with two children is not unlinked itself: its in-order
    // successor's key/value is swapped into it and the successor node is
    // deleted instead, so iterators to the successor entry are invalidated.
    std::size_t erase(const KeyT &key) {
      if (mRoot == nullptr) return 0;

      node *parent = nullptr;
      node *ptr = find_node(key,mRoot,parent);
      if (ptr == nullptr) return 0;
      if (ptr->left != nullptr && ptr->right != nullptr) {
        // two children: splice out the minimum of the right subtree
        // (it has no left child) and move its entry into ptr
        node *min = find_min_node(ptr->right);
        node * &link = child_link(min->parent, min->data.first);
        link = (min->left == nullptr) ? min->right : min->left;
        if (link != nullptr) link->parent = min->parent;
        std::swap(ptr->data.first, min->data.first);
        std::swap(ptr->data.second, min->data.second);
        ptr = min; // we are going to delete this node instead
      } else {
        // zero or one child: replace ptr by that child (possibly null)
        node * &link = child_link(ptr->parent, key);
        link = (ptr->left == nullptr) ? ptr->right : ptr->left;
        if (link != nullptr) link->parent = ptr->parent;
      }
      delete ptr;
      mSize--;
      return 1;
    }

    // Debug dump of the subtree rooted at n, right subtree first, with each
    // entry indented by "--" per level of depth. Also reports any child
    // whose parent link does not point back at n.
    void print_node(node *n,std::size_t depth) {
      if (n != nullptr) {
        if (n->right != nullptr && n->right->parent != n)
          std::cout << "parent of " << n->right->data.first << std::endl;
        if (n->left != nullptr && n->left->parent != n)
          std::cout << "parent of " << n->left->data.first << std::endl;
        print_node(n->right,depth+1);
        for (std::size_t i = 0;i < depth;i++) std::cout << "--";
        std::cout << " " << n->data.first << ":" << n->data.second << std::endl;
        print_node(n->left,depth+1);
      }
    }

    // Prints the size and a sideways dump of the whole tree to std::cout.
    void print() {
      std::cout << " ======== size = " << mSize << " ========= " << std::endl;
      print_node(mRoot,0);
    }

    // True when every child's parent link points back at its parent.
    bool checkParent() {
      return checkParent(mRoot);
    }

    bool checkParent(node *r) {
      if (r == nullptr) return true;
      if (r->left != nullptr && r != r->left->parent) return false;
      if (r->right != nullptr && r != r->right->parent) return false;
      if (!checkParent(r->left)) return false;
      return checkParent(r->right);
    }

    // True when the tree satisfies the full binary-search-tree ordering:
    // every key in a left subtree orders strictly before its ancestor and
    // every key in a right subtree strictly after (not just the direct
    // children).
    bool checkInorder() {
      return checkInorder(mRoot);
    }

    bool checkInorder(node *r) {
      return check_inorder(r, nullptr, nullptr);
    }

};

}

#endif [/code]

The testing program [code lang="c++"] #include #include #include #include #include "map_bst.h" #include <assert.h>

// Exercises operator[], insert() and erase() on an int -> string map,
// dumping the tree after each round of changes, then walks forward from the
// (already present) key 12 to the end of the map.
bool test1() {
  CP::map_bst<int,std::string> m;
  m[10] = "A";
  m[8]  = "B";
  m[6]  = "C";
  m[11] = "D";
  m[12] = "E";
  m.print();

  m.insert(std::make_pair(13,"x"));
  m.insert(std::make_pair(7,"x"));
  m[5] = "E";
  m.print();

  // Remove keys one at a time so each restructuring is visible.
  const int victims[] = {10, 8, 5};
  for (int key : victims) {
    m.erase(key);
    m.print();
  }

  // Re-inserting an existing key must fail and point at the existing entry.
  auto result = m.insert(std::make_pair(12,"x"));
  assert(!result.second);
  for (auto it = result.first; it != m.end(); ++it) {
    std::cout << it->first << ": " << it->second << std::endl;
  }
  return true;
}

// Fills an int -> bool map through chained operator[] assignments and then
// prints every entry in key order.
bool test2() {
  CP::map_bst<int,bool> m;
  m[1] = m[2] = m[3] = m[4] = m[20] = m[11] = m[5] = m[9] = m[7] = false;
  m[-4] = m[-2] = m[-1] = m[-3] = true;

  for (const auto& entry : m) {
    std::cout << entry.first << ": " << entry.second << std::endl;
  }
  return true;
}

// Sample key type used by the tests. operator< compares only `name`, so two
// objects with the same name but different `value`s are equivalent keys for
// an ordering built on std::less<TestClass>.
class TestClass {
  public:
    std::string name;
    int value;

    TestClass() : name(""), value(0) { }

    // Takes the name by value and moves it into the member.
    TestClass(std::string n, int v) : name(std::move(n)), value(v) { }

    // No copy constructor is declared: the implicit copy and move members
    // copy/move both fields member-wise, and a user-declared copy
    // constructor would suppress the implicit move constructor.

    // Equality on both fields; traces every comparison to stdout.
    bool operator==(const TestClass& other) const {
      std::cout << "comparing " << (*this) << " and " << other << std::endl;
      return other.name == name && other.value == value;
    }

    // Orders by name only.
    bool operator<(const TestClass& other) const {
      return name < other.name;
    }

    friend std::ostream& operator<<(std::ostream& os,const TestClass& c) {
      os << "(name:" << c.name << ", value:" << c.value << ")";
      return os;
    }

};

// Function-pointer comparator for test4: orders TestClass keys by name,
// largest name first.
bool comparator(const TestClass &a,const TestClass &b) {
  return b.name < a.name;
}

// Function-object comparator for test4: orders TestClass keys by `value`
// only, so entries with equal values are equivalent keys. The call operator
// is const so the functor is also usable through a const comparator object,
// as standard ordered containers require.
class ComparatorClass {
  public:
    bool operator()(const TestClass &a,const TestClass &b) const {
      return a.value < b.value;
    }
};

// CompFunc: pointer to a function that orders two TestClass objects
// (e.g. `comparator`); used as the CompareT argument of map_bst in test4.
using CompFunc = bool (*)(const TestClass&, const TestClass&);

bool test3() { CP::map_bst<TestClass,float> m; CP::map_bst<TestClass,float> m2;

m[TestClass("somchai",1)] = 0.1; m[TestClass("nuttapong",1)] = 1.1; m[TestClass("nattee",1)] = 2.2; m[TestClass("vishnu",1)] = 3.1; m.print(); m[TestClass("somchai",1)] = 0.2; m[TestClass("somchai",2)] = 99; m.print(); m[TestClass("xxx",1)] = 11; m[TestClass("x1",2)] = 23; m2 = m; m2.print(); return true; }

bool test4() { std::vector<std::pair<TestClass,float> > data;

data.push_back(std::make_pair(TestClass("somchai",1),99)); data.push_back(std::make_pair(TestClass("nuttapong",1),1.1)); data.push_back(std::make_pair(TestClass("nattee",1),2.2)); data.push_back(std::make_pair(TestClass("vishnu",1),3.1)); data.push_back(std::make_pair(TestClass("xxx",1),11)); data.push_back(std::make_pair(TestClass("x1",2),23));

CP::map_bst<TestClass,float,ComparatorClass> m; for (auto& x : data) { m[x.first] = x.second; } m.print();

CP::map_bst<TestClass,float,CompFunc> m2(comparator); for (auto& x : data) { m2[x.first] = x.second; } m2.print();

return true; }

// Runs the four tests in order and prints a banner for each one that
// reports success.
int main() {
  struct Case {
    bool (*run)();
    const char* banner;
  };
  const Case cases[] = {
    { test1, "-------------------- TEST 1 OK -------------------" },
    { test2, "-------------------- TEST 2 OK -------------------" },
    { test3, "-------------------- TEST 3 OK -------------------" },
    { test4, "-------------------- TEST 4 OK -------------------" },
  };

  for (const Case& c : cases) {
    if (c.run()) {
      std::cout << c.banner << std::endl;
    }
  }
}