#ifndef MAP_H
#define MAP_H
#include <cassert> //assert
#include <cstddef> //size_t
#include <functional> //less
#include <iostream> //ostream
#include <utility> //pair
// Map: an ordered associative container mapping unique keys (Key) to values
// (Value), ordered by Compare (std::less<Key> unless another is provided).
//
// Implemented as an unbalanced binary search tree. Lookup, insertion and each
// iterator increment cost O(height), which is O(n) in the worst case (e.g. keys
// inserted in sorted order). Lookups are iterative so a degenerate tree does
// not deepen the call stack on those paths.
template <typename Key,
          typename Value,
          typename Compare=std::less<Key> // default if argument isn't provided
         >
class Map {
  // A tree node. Nodes are allocated with `new` by this class and released by
  // destroy_nodes_impl; a node's children are reached only through it.
  struct Node {
    Node(const std::pair<const Key, Value> &datum_in, Node *left_in, Node *right_in)
      : datum(datum_in), left(left_in), right(right_in) { }
    std::pair<const Key, Value> datum;
    Node *left;
    Node *right;
  };

public:
  // Creates an empty map.
  Map()
    : root(nullptr) { }

  // Deep-copies every element (and the comparator) of `other`.
  Map(const Map &other)
    : root(copy_nodes_impl(other.root)), less(other.less) { }

  // Copy assignment. The replacement tree is built before the current one is
  // released, so if copying the nodes throws, *this is left unchanged.
  Map &operator=(const Map &rhs) {
    if (this == &rhs) {
      return *this;
    }
    Node *replacement = copy_nodes_impl(rhs.root);
    destroy_nodes_impl(root);
    root = replacement;
    less = rhs.less;
    return *this;
  }

  // Frees every node.
  ~Map() {
    destroy_nodes_impl(root);
  }

  // True when the map holds no elements.
  bool empty() const {
    return empty_impl(root);
  }

  // Number of elements; walks the whole tree (O(n)).
  size_t size() const {
    return size_impl(root);
  }

  // Forward iterator that visits elements in ascending key order.
  // Only the node pointer takes part in equality, so every past-the-end
  // iterator compares equal to end().
  class Iterator {
  public:
    // Creates a past-the-end iterator.
    Iterator()
      : root(nullptr), current_node(nullptr) {}

    // Precondition: not past-the-end.
    std::pair<const Key, Value> &operator*() const {
      return current_node->datum;
    }

    // Precondition: not past-the-end.
    std::pair<const Key, Value> *operator->() const {
      return &current_node->datum;
    }

    // Advances to the in-order successor. Precondition: not past-the-end.
    Iterator &operator++() {
      if (current_node->right) {
        // Successor is the leftmost node of the right subtree.
        current_node = min_element_impl(current_node->right);
      }
      else {
        // No right subtree: nodes keep no parent pointer, so search from the
        // root for the smallest key greater than the current one.
        current_node = min_greater_than_impl(root, current_node->datum.first, less);
      }
      return *this;
    }

    // Post-increment: returns a copy of the iterator before advancing.
    Iterator operator++(int) {
      Iterator result(*this);
      ++(*this);
      return result;
    }

    bool operator==(const Iterator &rhs) const {
      return current_node == rhs.current_node;
    }

    bool operator!=(const Iterator &rhs) const {
      return current_node != rhs.current_node;
    }

  private:
    friend class Map;
    Node *root;          // tree root, needed to find successors
    Node *current_node;  // nullptr means past-the-end
    Compare less;

    Iterator(Node *root_in, Node* current_node_in, Compare less_in)
      : root(root_in), current_node(current_node_in), less(less_in) { }
  };

  // Iterator to the smallest key, or end() if the map is empty.
  Iterator begin() const {
    if (root == nullptr) {
      return Iterator();
    }
    return Iterator(root, min_element_impl(root), less);
  }

  // Past-the-end iterator.
  Iterator end() const {
    return Iterator();
  }

  // Iterator to the smallest key, or end() if empty.
  Iterator min_element() const {
    return Iterator(root, min_element_impl(root), less);
  }

  // Iterator to the largest key, or end() if empty.
  Iterator max_element() const {
    return Iterator(root, max_element_impl(root), less);
  }

  // Iterator to the smallest key strictly greater than `value`, or end().
  Iterator min_greater_than(const Key &value) const {
    return Iterator(root, min_greater_than_impl(root, value, less), less);
  }

  // Iterator to the element whose key is equivalent to `query`, or end().
  Iterator find(const Key &query) const {
    return Iterator(root, find_impl(root, query, less), less);
  }

  // Inserts `item` if no element with an equivalent key exists.
  // Returns {iterator to the inserted element, true}, or
  // {iterator to the existing element, false} if the key was already present
  // (in which case the map is unchanged).
  std::pair<Iterator, bool> insert(const std::pair<const Key, Value> &item) {
    Iterator existing = find(item.first);
    if (existing != end()) {
      return {existing, false};
    }
    root = insert_impl(root, item, less);
    return {find(item.first), true};
  }

  // Returns a reference to the value mapped to `key`, first inserting a
  // value-initialized Value if `key` is absent.
  Value &operator[](const Key &key) {
    Iterator it = find(key);
    if (it == end()) {
      it = insert({key, Value()}).first;
    }
    return (*it).second;
  }

private:
  Node *root;
  Compare less;

  static bool empty_impl(const Node *node) {
    return (!node);
  }

  // Counts the nodes in the subtree rooted at `node`.
  static size_t size_impl(const Node *node) {
    if (!node) {
      return 0;
    }
    return 1 + size_impl(node->left) + size_impl(node->right);
  }

  // Returns a deep copy of the subtree rooted at `node`. If any allocation or
  // element copy throws, the partially built copy is freed and the exception
  // is rethrown.
  static Node *copy_nodes_impl(const Node *node) {
    if (!node) {
      return nullptr;
    }
    Node *left_copy = copy_nodes_impl(node->left);
    Node *right_copy = nullptr;
    try {
      right_copy = copy_nodes_impl(node->right);
      return new Node(node->datum, left_copy, right_copy);
    } catch (...) {
      destroy_nodes_impl(left_copy);
      destroy_nodes_impl(right_copy);
      throw;
    }
  }

  // Deletes every node in the subtree rooted at `node` (post-order).
  static void destroy_nodes_impl(Node *node) {
    if (node) {
      destroy_nodes_impl(node->left);
      destroy_nodes_impl(node->right);
      delete node;
    }
  }

  // Returns the node whose key is equivalent to `query`, or nullptr.
  static Node * find_impl(Node *node, const Key &query, Compare less) {
    while (node) {
      if (less(query, node->datum.first)) {
        node = node->left;
      } else if (less(node->datum.first, query)) {
        node = node->right;
      } else {
        return node;  // neither key orders before the other: equivalent
      }
    }
    return nullptr;
  }

  // Attaches a new leaf holding `item` and returns the (possibly new) root of
  // the subtree. Keys not less than a node's key go right; callers ensure the
  // key is not already present.
  static Node * insert_impl(Node *node, const std::pair<const Key, Value> &item, Compare less) {
    Node *leaf = new Node(item, nullptr, nullptr);
    if (!node) {
      return leaf;
    }
    Node *cursor = node;
    for (;;) {
      Node *&next = less(item.first, cursor->datum.first) ? cursor->left
                                                          : cursor->right;
      if (!next) {
        next = leaf;
        return node;
      }
      cursor = next;
    }
  }

  // Leftmost node of the subtree, or nullptr for an empty subtree.
  static Node * min_element_impl(Node *node) {
    if (node) {
      while (node->left) {
        node = node->left;
      }
    }
    return node;
  }

  // Rightmost node of the subtree, or nullptr for an empty subtree.
  static Node * max_element_impl(Node *node) {
    if (node) {
      while (node->right) {
        node = node->right;
      }
    }
    return node;
  }

  // Node with the smallest key strictly greater than `val`, or nullptr.
  static Node * min_greater_than_impl(Node *node, const Key &val, Compare less) {
    Node *best = nullptr;
    while (node) {
      if (less(val, node->datum.first)) {
        // Candidate; anything smaller that still qualifies lies to the left.
        best = node;
        node = node->left;
      } else {
        // Key <= val: only the right subtree can hold greater keys.
        node = node->right;
      }
    }
    return best;
  }
};
#endif