The map class, implemented with an AVL tree [code lang="c++"] #ifndef CP_MAP_AVL_INCLUDED #define CP_MAP_AVL_INCLUDED
#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <utility>
namespace CP {

/**
 * Ordered key -> value map backed by an AVL tree.
 *
 * Each node caches the height of its subtree; after every insertion or
 * erasure the nodes on the modified path are rebalanced with single or
 * double rotations so that sibling subtree heights differ by at most one.
 * Nodes keep a parent pointer so iterators can walk the tree in order.
 *
 * @tparam KeyT     key type, ordered by CompareT
 * @tparam MappedT  value type stored with each key
 * @tparam CompareT strict-weak-ordering predicate on KeyT
 *                  (default std::less<KeyT>)
 */
template <typename KeyT,
          typename MappedT,
          typename CompareT = std::less<KeyT> >
class map_avl {
  public:
    // Element type stored in the map (the key is not const here).
    typedef std::pair<KeyT,MappedT> ValueT;

  protected:
    class tree_iterator;

    // One tree node: the stored element, child/parent links and the height
    // of the subtree rooted at this node (a single leaf has height 0).
    class node {
      friend class map_avl;
      friend class tree_iterator;
      protected:
        ValueT data;
        node *left;
        node *right;
        node *parent;
        int height;

        node() :
          data( ValueT() ), left( nullptr ), right( nullptr ), parent( nullptr ), height(0) { }

        node(const ValueT& data, node* left, node* right, node* parent) :
          data ( data ), left( left ), right( right ), parent( parent ) {
          set_height();
        }

        // Height of a possibly-empty subtree; an empty subtree counts as -1.
        static int get_height(node *n) {
          return (n == nullptr ? -1 : n->height);
        }

        // Recompute this node's height from its children's cached heights.
        void set_height() {
          int hL = get_height(this->left);
          int hR = get_height(this->right);
          height = 1 + (hL > hR ? hL : hR);
        }

        // Right height minus left height; an AVL node keeps this in [-1, 1].
        int balance_value() {
          return get_height(this->right) - get_height(this->left);
        }

        // Attach n as the left child and point its parent link back here.
        void set_left(node *n) {
          this->left = n;
          if (n != nullptr) this->left->parent = this;
        }

        // Attach n as the right child and point its parent link back here.
        void set_right(node *n) {
          this->right = n;
          if (n != nullptr) this->right->parent = this;
        }
    };

    //-------------------------------------------------------------
    // Bidirectional in-order iterator. end() holds a null pointer, so it
    // must not be dereferenced or decremented.
    class tree_iterator {
      protected:
        node* ptr;
      public:
        tree_iterator() : ptr( nullptr ) { }
        tree_iterator(node *a) : ptr(a) { }

        // Move to the in-order successor.
        tree_iterator& operator++() {
          if (ptr->right == nullptr) {
            // Climb while we arrive from a right child; the first ancestor
            // reached from its left side is the successor (null if none).
            node *parent = ptr->parent;
            while (parent != nullptr && parent->right == ptr) {
              ptr = parent;
              parent = ptr->parent;
            }
            ptr = parent;
          } else {
            // Successor is the leftmost node of the right subtree.
            ptr = ptr->right;
            while (ptr->left != nullptr)
              ptr = ptr->left;
          }
          return (*this);
        }

        // Move to the in-order predecessor (mirror image of operator++).
        tree_iterator& operator--() {
          if (ptr->left == nullptr) {
            node *parent = ptr->parent;
            while (parent != nullptr && parent->left == ptr) {
              ptr = parent;
              parent = ptr->parent;
            }
            ptr = parent;
          } else {
            ptr = ptr->left;
            while (ptr->right != nullptr)
              ptr = ptr->right;
          }
          return (*this);
        }

        tree_iterator operator++(int) {
          tree_iterator tmp(*this);
          operator++();
          return tmp;
        }

        tree_iterator operator--(int) {
          tree_iterator tmp(*this);
          operator--();
          return tmp;
        }

        ValueT& operator*() const { return ptr->data; }
        ValueT* operator->() const { return &(ptr->data); }
        bool operator==(const tree_iterator& other) const { return other.ptr == ptr; }
        bool operator!=(const tree_iterator& other) const { return other.ptr != ptr; }
    };

    //-------------------------------------------------------------
    node *mRoot;
    CompareT mLess;
    size_t mSize;

  public:
    typedef tree_iterator iterator;

  protected:
    // Three-way comparison built from the strict ordering: -1, 0 or +1.
    int compare(const KeyT& k1, const KeyT& k2) {
      if (mLess(k1, k2)) return -1;
      if (mLess(k2, k1)) return +1;
      return 0;
    }

    // Find the node holding key in subtree r (null if absent). parent is set
    // to the last node visited before the match / the fall-off point.
    node* find_node(const KeyT& key, node* r, node* &parent) {
      if (r == nullptr) return nullptr;
      int cmp = compare(key, r->data.first);
      if (cmp == 0) return r;
      parent = r;
      return find_node(key, cmp < 0 ? r->left : r->right, parent);
    }

    // Leftmost node of subtree r; r must not be null.
    node* find_min_node(node* r) {
      node* min = r;
      while (min->left != nullptr) {
        min = min->left;
      }
      return min;
    }

    // Rightmost node of subtree r; r must not be null.
    node* find_max_node(node* r) {
      node* max = r;
      while (max->right != nullptr) {
        max = max->right;
      }
      return max;
    }

    // Deep-copy subtree src, hanging the copy under parent. The cached
    // heights are copied too so the new tree satisfies the same AVL
    // invariants as the source.
    node* copy(node* src, node* parent) {
      if (src == nullptr) return nullptr;
      node* tmp = new node(src->data, nullptr, nullptr, parent);
      tmp->left = copy(src->left, tmp);
      tmp->right = copy(src->right, tmp);
      tmp->height = src->height;
      return tmp;
    }

    // Free every node of subtree r (post-order).
    void delete_all_nodes(node* r) {
      if (r == nullptr) return;
      delete_all_nodes(r->left);
      delete_all_nodes(r->right);
      delete r;
    }

    // Promote r's left child to subtree root (a right rotation). The new
    // root's parent link is fixed by the caller.
    node* rotate_left_child(node* r) {
      node* new_root = r->left;
      r->set_left(new_root->right);
      new_root->set_right(r);
      new_root->right->set_height();
      new_root->set_height();
      return new_root;
    }

    // Promote r's right child to subtree root (a left rotation). The new
    // root's parent link is fixed by the caller.
    node* rotate_right_child(node* r) {
      node* new_root = r->right;
      r->set_right(new_root->left);
      new_root->set_left(r);
      new_root->left->set_height();
      new_root->set_height();
      return new_root;
    }

    // Restore the AVL property at r (children already balanced) and return
    // the possibly new subtree root with its height refreshed.
    node* rebalance(node* r) {
      if (r == nullptr) return r;
      int balance = r->balance_value();
      if (balance == -2) {
        // Left-heavy; a right-leaning left child needs a double rotation.
        if (r->left->balance_value() == 1) {
          r->set_left(rotate_right_child(r->left));
        }
        r = rotate_left_child(r);
      } else if (balance == 2) {
        // Right-heavy; a left-leaning right child needs a double rotation.
        if (r->right->balance_value() == -1) {
          r->set_right(rotate_left_child(r->right));
        }
        r = rotate_right_child(r);
      }
      r->set_height();
      return r;
    }

    // Insert val into subtree r unless its key already exists. ptr receives
    // the node holding the key; returns the rebalanced subtree root.
    node* insert(const ValueT& val, node *r, node* &ptr) {
      if (r == nullptr) {
        mSize++;
        ptr = r = new node(val, nullptr, nullptr, nullptr);
      } else {
        int cmp = compare(val.first, r->data.first);
        if (cmp == 0) ptr = r;
        else if (cmp < 0) r->set_left(insert(val, r->left, ptr));
        else r->set_right(insert(val, r->right, ptr));
      }
      return rebalance(r);
    }

    // Unlink the minimum node of the non-empty subtree r. The unlinked node
    // is returned through min_out; the return value is the rebalanced root
    // of what remains.
    node* detach_min(node* r, node* &min_out) {
      if (r->left == nullptr) {
        min_out = r;
        return r->right;
      }
      r->set_left(detach_min(r->left, min_out));
      return rebalance(r);
    }

    // Remove key from subtree r (no-op if absent); returns the rebalanced
    // subtree root. A node with two children is replaced by its in-order
    // successor node, so no other element's storage moves.
    node* erase(const KeyT &key, node *r) {
      if (r == nullptr) return nullptr;
      int cmp = compare(key, r->data.first);
      if (cmp < 0) {
        r->set_left(erase(key, r->left));
      } else if (cmp > 0) {
        r->set_right(erase(key, r->right));
      } else {
        // key may refer into victim's data; it is not read after the delete.
        node *victim = r;
        if (r->left == nullptr || r->right == nullptr) {
          r = (r->left == nullptr ? r->right : r->left);
        } else {
          node *succ = nullptr;
          node *rest = detach_min(r->right, succ);
          succ->set_left(r->left);
          succ->set_right(rest);
          r = succ;
        }
        delete victim;
        mSize--;
      }
      return rebalance(r);
    }

  public:
    //-------------- constructor & copy operator ----------

    // copy constructor
    map_avl(const map_avl<KeyT,MappedT,CompareT> & other) :
      mRoot(nullptr), mLess(other.mLess), mSize(other.mSize)
    {
      mRoot = copy(other.mRoot, nullptr);
    }

    // default constructor
    map_avl(const CompareT& c = CompareT() ) :
      mRoot(nullptr), mLess(c), mSize(0)
    { }

    // copy assignment operator using copy-and-swap idiom
    map_avl<KeyT,MappedT,CompareT>& operator=(map_avl<KeyT,MappedT,CompareT> other) {
      // other is a by-value copy; after swapping, its destructor frees our
      // previous contents.
      using std::swap;
      swap(this->mRoot, other.mRoot);
      swap(this->mLess, other.mLess);
      swap(this->mSize, other.mSize);
      return *this;
    }

    ~map_avl() {
      clear();
    }

    bool empty() const {
      return mSize == 0;
    }

    size_t size() const {
      return mSize;
    }

    // Iterator to the smallest key, or end() when the map is empty.
    iterator begin() {
      return iterator(mRoot == nullptr ? nullptr : find_min_node(mRoot));
    }

    iterator end() {
      return iterator(nullptr);
    }

    // Iterator to key, or end() if it is not present.
    iterator find(const KeyT &key) {
      node *parent = nullptr;
      node *ptr = find_node(key, mRoot, parent);
      return ptr == nullptr ? end() : iterator(ptr);
    }

    void clear() {
      delete_all_nodes(mRoot);
      mRoot = nullptr;
      mSize = 0;
    }

    // Value for key, inserting a value-initialized MappedT first if absent.
    MappedT& operator[](const KeyT& key) {
      std::pair<iterator,bool> result = insert(std::make_pair(key, MappedT()));
      return result.first->second;
    }

    // Insert val if its key is absent. Returns the element with that key and
    // whether an insertion happened.
    std::pair<iterator,bool> insert(const ValueT& val) {
      node *ptr = nullptr;
      size_t s = mSize;
      mRoot = insert(val, mRoot, ptr);
      mRoot->parent = nullptr;
      return std::make_pair(iterator(ptr), (mSize > s));
    }

    // Remove key; returns the number of elements removed (0 or 1).
    size_t erase(const KeyT &key) {
      size_t s = mSize;
      mRoot = erase(key, mRoot);
      // The recursive helper relinks every parent pointer except the root's.
      // After deleting the root or rotating at it, the new root would still
      // point at a freed node (or at its own child), so reset it here.
      if (mRoot != nullptr) mRoot->parent = nullptr;
      return s == mSize ? 0 : 1;
    }

    //----------------------------------------------------------------
    // Debug printout: right subtree first, so the tree reads rotated 90 deg.
    // Also reports children whose parent link is wrong.
    void print_node(node *n, size_t depth) {
      if (n != nullptr) {
        if (n->right != nullptr && n->right->parent != n)
          std::cout << "parent of " << n->right->data.first << std::endl;
        if (n->left != nullptr && n->left->parent != n)
          std::cout << "parent of " << n->left->data.first << std::endl;
        print_node(n->right, depth+1);
        for (size_t i = 0; i < depth; i++) std::cout << "--";
        std::cout << " " << n->data.first << ":" << n->data.second << std::endl;
        print_node(n->left, depth+1);
      }
    }

    void print() {
      std::cout << " ======== size = " << mSize << " ========= " << std::endl;
      print_node(mRoot, 0);
    }

    // True when the root has no parent and every child links back to its
    // parent.
    bool checkParent() {
      if (mRoot != nullptr && mRoot->parent != nullptr) return false;
      return checkParent(mRoot);
    }

    bool checkParent(node *r) {
      if (r == nullptr) return true;
      if (r->left != nullptr && r != r->left->parent) return false;
      if (r->right != nullptr && r != r->right->parent) return false;
      if (!checkParent(r->left)) return false;
      return checkParent(r->right);
    }

    // True when the whole tree is a strict binary search tree under mLess.
    bool checkInorder() {
      return checkInorder(mRoot);
    }

    bool checkInorder(node *r) {
      return checkInorder(r, nullptr, nullptr);
    }

    // Every key in subtree r must lie strictly between *lo and *hi (a null
    // bound means unbounded on that side).
    bool checkInorder(node *r, const KeyT *lo, const KeyT *hi) {
      if (r == nullptr) return true;
      const KeyT& k = r->data.first;
      if (lo != nullptr && !mLess(*lo, k)) return false;
      if (hi != nullptr && !mLess(k, *hi)) return false;
      return checkInorder(r->left, lo, &k) && checkInorder(r->right, &k, hi);
    }

    // Height of subtree r recomputed from scratch (ignores cached heights).
    int height(node *r) {
      if (r == nullptr) return -1;
      int hl = height(r->left);
      int hr = height(r->right);
      return 1 + (hl > hr ? hl : hr);
    }

    // Tree height (-1 when empty); asserts the cached root height matches.
    int height() {
      if (mRoot == nullptr) return -1;
      assert(height(mRoot) == mRoot->height);
      return mRoot->height;
    }
};

}
#endif
[/code]
The test code
[code lang="c++"]
#include
bool test1() { CP::map_avl<int,std::string> m; m[10] = "A"; m[8] = "B"; m[6] = "C"; m[11] = "D"; m[12] = "E";
m.print(); m.insert(std::make_pair(13,"x")); m.insert(std::make_pair(7,"x")); m[5] = "E"; m.print(); m.erase(10); m.print(); m.erase(8); m.print(); m.erase(5); m.print();
auto result = m.insert(std::make_pair(12,"x")); assert(result.second == false); while (result.first != m.end()) { std::cout << result.first->first << ": " << result.first->second << std::endl; result.first++; } return true; }
bool test2() { CP::map_avl<int,bool> m; m[1] = m[2] = m[3] = m[4] = m[20] = m [11] = m[5] = m[9] = m[7] = false; m[-4] = m[-2] = m[-1] = m[-3] = true; for (auto& x : m) { std::cout << x.first << ": " << x.second << std::endl; }
return true; }
// Sample user-defined key type for the map tests. operator< orders by `name`
// only, so two objects with the same name but different `value` are
// equivalent keys under the default comparator.
class TestClass {
  public:
    std::string name;
    int value;

    TestClass() : name(""), value(0) { }
    // n is taken by value and moved into place (sink parameter).
    TestClass(std::string n, int v) : name(std::move(n)), value(v) { }
    // Copy/move construction and assignment are compiler-generated (Rule of
    // Zero): member-wise copy, plus cheap noexcept moves.

    // Logs each call to std::cout, then compares both fields.
    bool operator==(const TestClass& other) const {
      std::cout << "comparing " << (*this) << " and " << other << std::endl;
      return other.name == name && other.value == value;
    }

    // Orders by name only; `value` does not participate.
    bool operator<(const TestClass& other) const {
      return name < other.name;
    }

    // Formats as "(name:<name>, value:<value>)".
    friend std::ostream& operator<<(std::ostream& os,const TestClass& c) {
      os << "(name:" << c.name << ", value:" << c.value << ")";
      return os;
    }
};
// Free-function comparator: orders TestClass objects by name in descending
// order (the reverse of TestClass::operator<).
bool comparator(const TestClass &a,const TestClass &b) {
  return b.name < a.name;
}
// Function-object comparator that orders TestClass objects by `value` only;
// objects with equal values are treated as the same key. The call operator is
// const so the comparator also works when held or passed as a const object.
class ComparatorClass {
  public:
    bool operator()(const TestClass &a,const TestClass &b) const { return a.value < b.value; }
};
// Pointer-to-function comparator type with the same signature as `comparator`;
// test4 uses it as the CompareT argument of CP::map_avl.
typedef bool(*CompFunc)(const TestClass&, const TestClass&); // Function pointer type named "CompFunc"
bool test3() { CP::map_avl<TestClass,float> m; CP::map_avl<TestClass,float> m2;
m[TestClass("somchai",1)] = 0.1; m[TestClass("nuttapong",1)] = 1.1; m[TestClass("nattee",1)] = 2.2; m[TestClass("vishnu",1)] = 3.1; m.print(); m[TestClass("somchai",1)] = 0.2; m[TestClass("somchai",2)] = 99; m.print(); m[TestClass("xxx",1)] = 11; m[TestClass("x1",2)] = 23; m2 = m; m2.print(); return true; }
bool test4() { std::vector<std::pair<TestClass,float> > data;
data.push_back(std::make_pair(TestClass("somchai",1),99)); data.push_back(std::make_pair(TestClass("nuttapong",1),1.1)); data.push_back(std::make_pair(TestClass("nattee",1),2.2)); data.push_back(std::make_pair(TestClass("vishnu",1),3.1)); data.push_back(std::make_pair(TestClass("xxx",1),11)); data.push_back(std::make_pair(TestClass("x1",2),23));
CP::map_avl<TestClass,float,ComparatorClass> m; for (auto& x : data) { m[x.first] = x.second; } m.print();
CP::map_avl<TestClass,float,CompFunc> m2(comparator); for (auto& x : data) { m2[x.first] = x.second; } m2.print();
return true; } bool test5() { CP::map_avl<int, int> m1; m1[4] = 1; m1[3] = 2; m1[6] = 9; CP::map_avl<int, int> m2(m1); m1[10] = 13;
CP::map_avl<int, int>::iterator it = m1.begin(); if ((it->first != 3) || (it->second != 2)) return false; it++; if ((it->first != 4) || (it->second != 1)) return false; it++; if ((it->first != 6) || (it->second != 9)) return false; it++; if ((it->first != 10) || (it->second != 13)) return false; it++; if (it != m1.end()) return false; it = m2.begin(); if ((it->first != 3) || (it->second != 2)) return false; it++; if ((it->first != 4) || (it->second != 1)) return false; it++; if ((it->first != 6) || (it->second != 9)) return false; it++; if (it != m2.end()) return false; m1.print(); m2.print(); return true; } //============================================== bool test_tree_sort() { int n = 10000; float *d1 = new float[n]; float *d2 = new float[n]; for (int i=0; i<n; i++) { d1[i] = rand(); d2[i] = d1[i]; } CP::map_avl<float,int> m; for (int i=0; i<n; i++) m[d1[i]]++; int k = 0; for (auto& v : m) { for (int i=0; i<v.second; i++) { d1[k++] = v.first; } } std::sort(d2, d2+n, [](const float& a, const float& b) -> bool {return a < b;} ); for (int i=0; i<n; i++) { assert(d1[i] == d2[i]); } return true; } //============================================== bool test_big_tree() { std::cout << "Testing big tree"; CP::map_avl<int, int> m; assert(m.erase(99) == 0); size_t n = 10000000; int *d = new int[n]; for (size_t i=0; i<n; i++) d[i] = i; for (size_t i=0; i<n; i++) { size_t j = rand()%n; int t = d[i]; d[i] = d[j]; d[j] = t; } for (size_t i=0; i<n/4; i++) { m[d[i]] = 1; assert(m[d[i]] == 1); if (i % (n/10) == 0) std::cout << "."; } for (size_t i=n/4; i<n/2; i++) { assert(m[d[i]] == 0); m[d[i]] = 1; assert(m[d[i]] == 1); if (i % (n/10) == 0) std::cout << "."; } for (size_t i=n/2; i<n; i++) { std::pair<CP::map_avl<int, int>::iterator,bool> p = m.insert(std::make_pair(d[i],1)); assert(p.second == true); assert((p.first)->first == d[i]);
p = m.insert(std::make_pair(d[i],1));
assert(p.second == false);
assert((p.first)->first == d[i]);
if (i % (n/10) == 0) std::cout << ".";
} assert(m.size() == n); for (size_t i=0; i<n/5; i+=2) assert(m.erase(d[i])==1); assert(m.size() == n - n/10); std::cout << "\nsize = " << m.size() << ", height = " << m.height() << std::endl; assert(m.checkParent()); assert(m.checkInorder()); int last = -999999999; for (auto & x : m) { assert(last < x.first); last = x.first; } return true; } //================================================= int main() { if (test1()) std::cout << "-------------------- TEST 1 OK -------------------" << std::endl; if (test2()) std::cout << "-------------------- TEST 2 OK -------------------" << std::endl; if (test3()) std::cout << "-------------------- TEST 3 OK -------------------" << std::endl; if (test4()) std::cout << "-------------------- TEST 4 OK -------------------" << std::endl; if (test5()) std::cout << "-------------------- TEST 5 OK -------------------" << std::endl; if (test_tree_sort()) std::cout << "-------------------- TEST TreeSort OK -------------------" << std::endl; if (test_big_tree()) std::cout << "-------------------- TEST BigTree OK -------------------" << std::endl; }
[/code]
- Log in to post comments