// Post-classifier-using-Machine-Learning / src / Map.h
#ifndef MAP_H
#define MAP_H

#include <cassert>    //assert
#include <cstddef>    //size_t
#include <functional> //less
#include <iostream>   //ostream
#include <utility>    //pair, swap

// An ordered associative container mapping unique Keys to Values, backed by
// an (unbalanced) binary search tree. Keys are ordered by Compare, which
// defaults to std::less<Key>. Two keys a and b are considered equivalent when
// neither less(a, b) nor less(b, a) holds.
template <typename Key,
          typename Value,
          typename Compare=std::less<Key> // default if argument isn't provided
         >
class Map {

  // One tree node; owns nothing itself (the Map owns every node).
  struct Node {

    Node() : datum(), left(nullptr), right(nullptr) {}

    Node(const std::pair<const Key, Value> &datum_in, Node *left_in, Node *right_in)
            : datum(datum_in), left(left_in), right(right_in) { }

    std::pair<const Key, Value> datum;
    Node *left;
    Node *right;
  };

public:

  // Creates an empty map.
  Map()
    : root(nullptr) { }

  // Deep-copies every node and the comparator of `other`.
  Map(const Map &other)
    : root(copy_nodes_impl(other.root)), less(other.less) { }

  // Copy assignment with the strong exception guarantee: the copy is built
  // first, so if it throws, *this is left unchanged.
  Map &operator=(const Map &rhs) {
    if (this != &rhs) {
      Map copy(rhs);
      std::swap(root, copy.root); // copy's destructor frees our old tree
      less = copy.less;
    }
    return *this;
  }

  ~Map() {
    destroy_nodes_impl(root);
  }

  // Returns true when the map holds no elements.
  bool empty() const {
    return empty_impl(root);
  }

  // Returns the number of key/value pairs (walks the whole tree: O(n)).
  std::size_t size() const {
    return size_impl(root);
  }

  // Forward iterator visiting elements in ascending key order. The end
  // iterator (and a default-constructed one) has a null current node.
  class Iterator {

  public:
    Iterator()
      : root(nullptr), current_node(nullptr), less() {}

    std::pair<const Key, Value> &operator*() const {
      return current_node->datum;
    }

    std::pair<const Key, Value> *operator->() const {
      return &current_node->datum;
    }

    // Advances to the in-order successor. Incrementing end() is an error.
    Iterator &operator++() {
      assert(current_node && "cannot increment an end iterator");
      if (current_node->right) {
        // Successor is the leftmost node of the right subtree.
        current_node = min_element_impl(current_node->right);
      }
      else {
        // Otherwise search from the root for the smallest greater key.
        current_node = min_greater_than_impl(root, current_node->datum.first, less);
      }
      return *this;
    }

    Iterator operator++(int) {
      Iterator result(*this);
      ++(*this);
      return result;
    }

    bool operator==(const Iterator &rhs) const {
      return current_node == rhs.current_node;
    }

    bool operator!=(const Iterator &rhs) const {
      return current_node != rhs.current_node;
    }

  private:
    friend class Map;

    Node *root;
    Node *current_node;
    Compare less;

    Iterator(Node *root_in, Node* current_node_in, Compare less_in)
      : root(root_in), current_node(current_node_in), less(less_in) { }

  };

  // Iterator to the smallest key, or end() when the map is empty.
  Iterator begin() const {
    if (root == nullptr) {
      return Iterator();
    }
    return Iterator(root, min_element_impl(root), less);
  }

  // Past-the-end iterator.
  Iterator end() const {
    return Iterator();
  }

  // Iterator to the smallest key (end() if empty).
  Iterator min_element() const {
    return Iterator(root, min_element_impl(root), less);
  }

  // Iterator to the largest key (end() if empty).
  Iterator max_element() const {
    return Iterator(root, max_element_impl(root), less);
  }

  // Iterator to the smallest key strictly greater than `value`, or end().
  Iterator min_greater_than(const Key &value) const {
    return Iterator(root, min_greater_than_impl(root, value, less), less);
  }

  // Iterator to the element whose key is equivalent to `query`, or end().
  Iterator find(const Key &query) const {
    return Iterator(root, find_impl(root, query, less), less);
  }

  // If an element with key item.first already exists, returns an iterator to
  // it and false (the map is not modified). Otherwise inserts a copy of
  // `item` and returns an iterator to the new element and true.
  std::pair<Iterator, bool> insert(const std::pair<const Key, Value> &item) {
    Iterator existing = find(item.first);
    if (existing != end()) {
      return {existing, false};
    }
    root = insert_impl(root, item, less);
    return {find(item.first), true};
  }

  // Returns a reference to the value mapped to `key`, inserting a
  // value-initialized Value first if the key is absent.
  Value &operator[](const Key &key) {
    Iterator it = find(key);
    if (it == end()) {
      it = insert({key, Value()}).first;
    }
    return it->second;
  }

private:

  Node *root;

  Compare less;

  static bool empty_impl(const Node *node) {
      return node == nullptr;
  }

  // Counts the nodes of the subtree rooted at `node`.
  static std::size_t size_impl(const Node *node) {
      if (!node) {
          return 0;
      }
      return 1 + size_impl(node->left) + size_impl(node->right);
  }

  // Returns a deep copy of the subtree rooted at `node` (nullptr for an empty
  // subtree). If any allocation or Value copy throws, the nodes copied so far
  // are freed before the exception propagates.
  static Node *copy_nodes_impl(const Node *node) {
      if (!node) {
          return nullptr;
      }
      Node *left_copy = copy_nodes_impl(node->left);
      Node *right_copy = nullptr;
      try {
          right_copy = copy_nodes_impl(node->right);
          return new Node(node->datum, left_copy, right_copy);
      } catch (...) {
          destroy_nodes_impl(left_copy);
          destroy_nodes_impl(right_copy);
          throw;
      }
  }

  // Frees every node of the subtree rooted at `node` (post-order).
  static void destroy_nodes_impl(Node *node) {
      if (node) {
          destroy_nodes_impl(node->left);
          destroy_nodes_impl(node->right);
          delete node;
      }
  }

  // Returns the node whose key is equivalent to `query`, or nullptr.
  static Node * find_impl(Node *node, const Key &query, Compare less) {
      while (node) {
          if (less(query, node->datum.first)) {
              node = node->left;
          } else if (less(node->datum.first, query)) {
              node = node->right;
          } else {
              return node;
          }
      }
      return nullptr;
  }

  // Inserts `item` as a new leaf and returns the (possibly new) subtree root.
  // Callers must ensure the key is not already present.
  static Node * insert_impl(Node *node, const std::pair<const Key, Value> &item, Compare less) {
      if (!node) {
          return new Node(item, nullptr, nullptr);
      }
      if (less(item.first, node->datum.first)) {
          node->left = insert_impl(node->left, item, less);
      } else {
          node->right = insert_impl(node->right, item, less);
      }
      return node;
  }

  // Leftmost (smallest-key) node of the subtree, or nullptr if empty.
  static Node * min_element_impl(Node *node) {
      if (!node) {
          return nullptr;
      }
      while (node->left) {
          node = node->left;
      }
      return node;
  }

  // Rightmost (largest-key) node of the subtree, or nullptr if empty.
  static Node * max_element_impl(Node *node) {
      if (!node) {
          return nullptr;
      }
      while (node->right) {
          node = node->right;
      }
      return node;
  }

  // Node with the smallest key strictly greater than `val`, or nullptr.
  static Node * min_greater_than_impl(Node *node, const Key &val, Compare less) {
      Node *best = nullptr;
      while (node) {
          if (less(val, node->datum.first)) {
              // Candidate; a smaller qualifying key can only be to the left.
              best = node;
              node = node->left;
          } else {
              node = node->right;
          }
      }
      return best;
  }


};

#endif