#ifndef MAP_H
#define MAP_H

#include <cassert>    // assert
#include <cstddef>    // size_t
#include <functional> // less
#include <iostream>   // ostream
#include <utility>    // pair

/// An ordered associative container mapping unique keys to values, stored in
/// an unbalanced binary search tree ordered by Compare.
///
/// Template parameters:
///   Key     - key type, ordered by Compare.
///   Value   - mapped type; must be default-constructible to use operator[].
///   Compare - strict weak ordering on Key (defaults to std::less<Key>); it is
///             default-constructed by Map(), so it must be default-constructible.
///
/// Two keys a and b are treated as equal when neither less(a, b) nor
/// less(b, a) holds. The tree is never rebalanced, so lookups and insertions
/// cost O(tree height), which degrades to O(n) for e.g. sorted insertions.
template <typename Key, typename Value,
          typename Compare = std::less<Key> // default if argument isn't provided
          >
class Map {
  // One tree node: owns a key-value pair and links to its two children.
  struct Node {
    Node() : left(nullptr), right(nullptr) {}

    Node(const std::pair<Key, Value> &datum_in, Node *left_in, Node *right_in)
        : datum(datum_in), left(left_in), right(right_in) {}

    std::pair<Key, Value> datum;
    Node *left;
    Node *right;
  };

public:
  /// Constructs an empty map.
  Map() : root(nullptr), less() {}

  /// Constructs a deep copy of other (nodes and comparator).
  Map(const Map &other) : root(copy_nodes_impl(other.root)), less(other.less) {}

  /// Takes ownership of other's nodes; other is left empty.
  Map(Map &&other) noexcept : root(other.root), less(other.less) {
    other.root = nullptr;
  }

  /// Replaces this map's contents with a deep copy of rhs.
  /// The copy is built before the old nodes are released, so if copying
  /// throws, this map is left unchanged.
  Map &operator=(const Map &rhs) {
    if (this == &rhs) {
      return *this;
    }
    Node *fresh = copy_nodes_impl(rhs.root);
    destroy_nodes_impl(root);
    root = fresh;
    less = rhs.less;
    return *this;
  }

  /// Releases this map's nodes and takes ownership of rhs's; rhs is left empty.
  Map &operator=(Map &&rhs) noexcept {
    if (this != &rhs) {
      destroy_nodes_impl(root);
      root = rhs.root;
      rhs.root = nullptr;
      less = rhs.less;
    }
    return *this;
  }

  /// Frees every node.
  ~Map() { destroy_nodes_impl(root); }

  /// Returns true when the map holds no elements.
  bool empty() const { return empty_impl(root); }

  /// Returns the number of elements (computed by walking the whole tree).
  std::size_t size() const { return size_impl(root); }

  /// Forward iterator visiting key-value pairs in ascending key order.
  /// A default-constructed Iterator compares equal to end().
  /// Iterators hold a raw snapshot of the map's root; they remain usable after
  /// insertions into a non-empty map (the root does not change), but not after
  /// the map is destroyed or reassigned.
  class Iterator {
  public:
    Iterator() : root(nullptr), current_node(nullptr), less() {}

    std::pair<Key, Value> &operator*() const { return current_node->datum; }

    std::pair<Key, Value> *operator->() const { return &current_node->datum; }

    /// Moves to the in-order successor (end() after the largest key).
    /// Must not be called on an end iterator.
    Iterator &operator++() {
      assert(current_node && "cannot increment an end iterator");
      if (current_node->right) {
        // The successor is the leftmost node of the right subtree.
        current_node = min_element_impl(current_node->right);
      } else {
        // Otherwise search from the root for the smallest key strictly greater
        // than the current one; nullptr (end) if there is none.
        current_node =
            min_greater_than_impl(root, current_node->datum.first, less);
      }
      return *this;
    }

    Iterator operator++(int) {
      Iterator result(*this);
      ++(*this);
      return result;
    }

    bool operator==(const Iterator &rhs) const {
      return current_node == rhs.current_node;
    }

    bool operator!=(const Iterator &rhs) const {
      return current_node != rhs.current_node;
    }

  private:
    friend class Map;

    Node *root;         // root of the owning map's tree; used for successor search
    Node *current_node; // node this iterator refers to; nullptr means end
    Compare less;       // key ordering used for successor search

    Iterator(Node *root_in, Node *current_node_in, Compare less_in)
        : root(root_in), current_node(current_node_in), less(less_in) {}
  };

  /// Iterator to the element with the smallest key, or end() if empty.
  Iterator begin() const {
    if (root == nullptr) {
      return Iterator();
    }
    return Iterator(root, min_element_impl(root), less);
  }

  /// Past-the-end iterator.
  Iterator end() const { return Iterator(); }

  /// Iterator to the element with the smallest key, or end() if empty.
  Iterator min_element() const {
    return Iterator(root, min_element_impl(root), less);
  }

  /// Iterator to the element with the largest key, or end() if empty.
  Iterator max_element() const {
    return Iterator(root, max_element_impl(root), less);
  }

  /// Iterator to the element with the smallest key strictly greater than
  /// value, or end() if there is none.
  Iterator min_greater_than(const Key &value) const {
    return Iterator(root, min_greater_than_impl(root, value, less), less);
  }

  /// Iterator to the element whose key equals query, or end() if absent.
  Iterator find(const Key &query) const {
    return Iterator(root, find_impl(root, query, less), less);
  }

  /// Inserts item if its key is not already present.
  /// Returns {iterator to the new element, true} on insertion, or
  /// {iterator to the existing element, false} if the key was already present
  /// (the existing value is left untouched).
  std::pair<Iterator, bool> insert(const std::pair<Key, Value> &item) {
    Iterator existing = find(item.first);
    if (existing != end()) {
      return {existing, false};
    }
    root = insert_impl(root, item, less);
    return {find(item.first), true};
  }

  /// Returns a reference to the value mapped to key, first inserting
  /// {key, Value()} if key is not present.
  Value &operator[](const Key &key) {
    Iterator it = find(key);
    if (it == end()) {
      it = insert({key, Value()}).first;
    }
    return it->second;
  }

private:
  Node *root;   // owned tree; nullptr when the map is empty
  Compare less; // ordering of keys

  static bool empty_impl(const Node *node) { return node == nullptr; }

  // Counts the nodes in the subtree rooted at node.
  static std::size_t size_impl(const Node *node) {
    if (!node) {
      return 0;
    }
    return 1 + size_impl(node->left) + size_impl(node->right);
  }

  // Returns a newly allocated deep copy of the subtree rooted at node.
  // If any allocation or element copy throws, everything already copied is
  // freed before the exception propagates.
  static Node *copy_nodes_impl(const Node *node) {
    if (!node) {
      return nullptr;
    }
    Node *left_copy = copy_nodes_impl(node->left);
    Node *right_copy = nullptr;
    try {
      right_copy = copy_nodes_impl(node->right);
      return new Node(node->datum, left_copy, right_copy);
    } catch (...) {
      destroy_nodes_impl(left_copy);
      destroy_nodes_impl(right_copy);
      throw;
    }
  }

  // Frees every node in the subtree rooted at node (post-order).
  static void destroy_nodes_impl(Node *node) {
    if (node) {
      destroy_nodes_impl(node->left);
      destroy_nodes_impl(node->right);
      delete node;
    }
  }

  // Returns the node whose key is equivalent to query, or nullptr.
  static Node *find_impl(Node *node, const Key &query, Compare less) {
    if (!node) {
      return nullptr;
    }
    if (less(query, node->datum.first)) {
      return find_impl(node->left, query, less);
    }
    if (less(node->datum.first, query)) {
      return find_impl(node->right, query, less);
    }
    return node;
  }

  // Inserts item as a new leaf in BST position and returns the (possibly new)
  // subtree root. Callers must ensure the key is not already present.
  static Node *insert_impl(Node *node, const std::pair<Key, Value> &item,
                           Compare less) {
    if (!node) {
      return new Node(item, nullptr, nullptr);
    }
    if (less(item.first, node->datum.first)) {
      node->left = insert_impl(node->left, item, less);
    } else {
      node->right = insert_impl(node->right, item, less);
    }
    return node;
  }

  // Leftmost node of the subtree, or nullptr if the subtree is empty.
  static Node *min_element_impl(Node *node) {
    if (!node) {
      return nullptr;
    }
    while (node->left) {
      node = node->left;
    }
    return node;
  }

  // Rightmost node of the subtree, or nullptr if the subtree is empty.
  static Node *max_element_impl(Node *node) {
    if (!node) {
      return nullptr;
    }
    while (node->right) {
      node = node->right;
    }
    return node;
  }

  // Node with the smallest key strictly greater than val, or nullptr.
  static Node *min_greater_than_impl(Node *node, const Key &val,
                                     Compare less) {
    if (!node) {
      return nullptr;
    }
    if (!less(val, node->datum.first)) {
      // node's key <= val: every candidate lies in the right subtree.
      return min_greater_than_impl(node->right, val, less);
    }
    // node's key > val: it is a candidate, but a smaller one may be on the left.
    Node *left_check = min_greater_than_impl(node->left, val, less);
    return left_check ? left_check : node;
  }
};

#endif