// Post-classifier-using-Machine-Learning / src / BinarySearchTree.h
#ifndef BINARY_SEARCH_TREE_H
#define BINARY_SEARCH_TREE_H

#include <cassert>  //assert
#include <cstddef>  //size_t
#include <functional> //less
#include <iostream> //ostream
#include <string>   //std::string

template <typename T,
          typename Compare=std::less<T> // default if argument isn't provided
         >
class BinarySearchTree {

  // One tree node. The tree owns every node reachable from `root`; nodes are
  // allocated with `new` and released by destroy_nodes_impl().
  struct Node {

    // Value-initializes the datum and nulls both child links so a
    // default-constructed node never carries garbage pointers.
    Node() : datum(), left(nullptr), right(nullptr) { }

    Node(const T &datum_in, Node *left_in, Node *right_in)
            : datum(datum_in), left(left_in), right(right_in) { }

    T datum;
    Node *left;
    Node *right;
  };

public:

  // Creates an empty tree.
  BinarySearchTree()
    : root(nullptr) { }

  // Deep-copies `other`, including its comparator (which may carry state).
  BinarySearchTree(const BinarySearchTree &other)
    : root(copy_nodes_impl(other.root)), less(other.less) { }

  // Replaces this tree's contents with a deep copy of `rhs`.
  // The new nodes are built BEFORE the old ones are released, so if copying
  // throws, *this is left unchanged instead of holding a dangling root.
  BinarySearchTree &operator=(const BinarySearchTree &rhs) {
    if (this == &rhs) {
      return *this;
    }
    Node *fresh = copy_nodes_impl(rhs.root);
    destroy_nodes_impl(root);
    root = fresh;
    less = rhs.less;
    return *this;
  }

  // Releases every node owned by the tree.
  ~BinarySearchTree() {
    destroy_nodes_impl(root);
  }

  // Returns true if the tree holds no elements.
  bool empty() const {
    return empty_impl(root);
  }

  // Returns the number of nodes on the longest root-to-leaf path
  // (0 for an empty tree, 1 for a single node).
  size_t height() const {
    return static_cast<size_t>(height_impl(root));
  }

  // Returns the number of elements in the tree.
  size_t size() const {
    return static_cast<size_t>(size_impl(root));
  }

  // Writes every element in sorted (in-order) order, each followed by " ".
  void traverse_inorder(std::ostream &os) const {
    traverse_inorder_impl(root, os);
  }

  // Writes every element in preorder (node, then left subtree, then right
  // subtree), each followed by " ".
  void traverse_preorder(std::ostream &os) const {
    traverse_preorder_impl(root, os);
  }

  // Returns true if, for every node, all elements in its left subtree compare
  // strictly less than it and all elements in its right subtree compare
  // strictly greater (under Compare). Equal elements violate the invariant.
  bool check_sorting_invariant() const {
    return check_sorting_invariant_impl(root, less);
  }

  // Forward iterator that visits elements in sorted order.
  // A default-constructed Iterator is the end() sentinel.
  class Iterator {

  public:
    Iterator()
      : root(nullptr), current_node(nullptr) {}

    // Precondition: the iterator is not end().
    T &operator*() const {
      return current_node->datum;
    }

    // Precondition: the iterator is not end().
    T *operator->() const {
      return &current_node->datum;
    }

    // Advances to the in-order successor; becomes end() after the maximum.
    // Precondition: the iterator is not end().
    Iterator &operator++() {
      if (current_node->right) {
        // Successor is the leftmost node of the right subtree.
        current_node = min_element_impl(current_node->right);
      }
      else {
        // No right subtree: the successor is the smallest element greater
        // than this one, found by searching from the root.
        current_node = min_greater_than_impl(root, current_node->datum, less);
      }
      return *this;
    }

    Iterator operator++(int) {
      Iterator result(*this);
      ++(*this);
      return result;
    }

    bool operator==(const Iterator &rhs) const {
      return current_node == rhs.current_node;
    }

    bool operator!=(const Iterator &rhs) const {
      return current_node != rhs.current_node;
    }

  private:
    friend class BinarySearchTree;

    Node *root;
    Node *current_node;
    Compare less;

    Iterator(Node *root_in, Node* current_node_in, Compare less_in)
      : root(root_in), current_node(current_node_in), less(less_in) { }

  };

  // Returns an iterator to the smallest element, or end() if empty.
  Iterator begin() const {
    if (root == nullptr) {
      return Iterator();
    }
    return Iterator(root, min_element_impl(root), less);
  }

  // Returns the past-the-end sentinel.
  Iterator end() const {
    return Iterator();
  }

  // Returns an iterator to the smallest element, or end() if empty.
  Iterator min_element() const {
    return Iterator(root, min_element_impl(root), less);
  }

  // Returns an iterator to the largest element, or end() if empty.
  Iterator max_element() const {
    return Iterator(root, max_element_impl(root), less);
  }

  // Returns an iterator to the smallest element strictly greater than
  // `value`, or end() if there is none.
  Iterator min_greater_than(const T &value) const {
    return Iterator(root, min_greater_than_impl(root, value, less), less);
  }

  // Returns an iterator to the element equivalent to `query`, or end().
  Iterator find(const T &query) const {
    return Iterator(root, find_impl(root, query, less), less);
  }

  // Inserts `item` and returns an iterator to it.
  // Precondition (asserted): no equivalent element is already in the tree.
  Iterator insert(const T &item) {
    assert(find(item) == end());
    root = insert_impl(root, item, less);
    return find(item);
  }

  std::string to_string() const;

private:

  Node *root;

  Compare less;

  class Tree_grid_square;
  class Tree_grid;

  int get_max_elt_width() const;

  // Returns true if the subtree rooted at `node` is empty.
  static bool empty_impl(const Node *node) {
      return (!node);
  }

  // Returns the number of nodes in the subtree rooted at `node`.
  static int size_impl(const Node *node) {
      if (!node)    {
          return 0;
      }
      return 1 + size_impl(node->left) + size_impl(node->right);
  }

  // Returns the height of the subtree rooted at `node` (0 if empty).
  static int height_impl(const Node *node) {
      if (!node)    {
          return 0;
      }
      // Compute each subtree height exactly once: re-evaluating the taller
      // side after comparing makes this exponential on skewed trees.
      const int left_height = height_impl(node->left);
      const int right_height = height_impl(node->right);
      return 1 + (left_height > right_height ? left_height : right_height);
  }

  // Returns a deep copy of the subtree rooted at `node` (nullptr if empty).
  // If any allocation or T copy throws, every node copied so far is freed
  // before the exception propagates.
  static Node *copy_nodes_impl(Node *node) {
      if (!node)    {
          return nullptr;
      }
      Node *left_copy = copy_nodes_impl(node->left);
      Node *right_copy = nullptr;
      try {
          right_copy = copy_nodes_impl(node->right);
          return new Node(node->datum, left_copy, right_copy);
      } catch (...) {
          destroy_nodes_impl(left_copy);
          destroy_nodes_impl(right_copy);
          throw;
      }
  }

  // Deletes every node in the subtree rooted at `node` (post-order).
  static void destroy_nodes_impl(Node *node) {
      if (node) {
          destroy_nodes_impl(node->left);
          destroy_nodes_impl(node->right);
          delete node;
      }
  }

  // Returns the node equivalent to `query` (neither compares less than the
  // other), or nullptr if none exists.
  static Node * find_impl(Node *node, const T &query, Compare less) {
      if (!node)    {
          return nullptr;
      }
      if (less(query, node->datum))    {
          return find_impl(node->left, query, less);
      } else if (less(node->datum, query)) {
          return find_impl(node->right, query, less);
      } else    {
          return node;
      }
  }

  // Inserts `item` into the subtree rooted at `node` and returns the
  // (possibly new) subtree root.
  static Node * insert_impl(Node *node, const T &item, Compare less) {
      if (!node)    {
          return new Node(item, nullptr, nullptr);
      }
      if (less(item, node->datum))  {
          node->left = insert_impl(node->left, item, less);
      } else    {
          node->right = insert_impl(node->right, item, less);
      }
      return node;
  }

  // Returns the leftmost (smallest) node, or nullptr if empty.
  static Node * min_element_impl(Node *node) {
      if (!node)    {
          return nullptr;
      }
      while (node->left) {
          node = node->left;
      }
      return node;
  }

  // Returns the rightmost (largest) node, or nullptr if empty.
  static Node * max_element_impl(Node *node) {
      if (!node)    {
          return nullptr;
      }
      while (node->right) {
          node = node->right;
      }
      return node;
  }

  // Returns true if the subtree rooted at `node` satisfies the strict BST
  // ordering under `less`. Runs in O(n) by passing value bounds downward.
  static bool check_sorting_invariant_impl(const Node *node, Compare less) {
      return check_within_bounds_impl(node, nullptr, nullptr, less);
  }

  // Returns true if every datum in the subtree rooted at `node` lies strictly
  // between *lower and *upper (a null bound means unbounded on that side)
  // and the subtree is itself correctly ordered. Nodes with a single child
  // are valid; only ordering violations make this return false.
  static bool check_within_bounds_impl(const Node *node, const T *lower,
                                       const T *upper, Compare less) {
      if (!node)    {
          return true;
      }
      if (lower && !less(*lower, node->datum)) {
          return false;
      }
      if (upper && !less(node->datum, *upper)) {
          return false;
      }
      return check_within_bounds_impl(node->left, lower, &node->datum, less)
          && check_within_bounds_impl(node->right, &node->datum, upper, less);
  }

  // Writes the subtree in-order (left, node, right), each datum + " ".
  static void traverse_inorder_impl(const Node *node, std::ostream &os) {
      if (node) {
          traverse_inorder_impl(node->left, os);
          os << node->datum << " ";
          traverse_inorder_impl(node->right, os);
      }
  }

  // Writes the subtree in preorder (node, left, right), each datum + " ".
  static void traverse_preorder_impl(const Node *node, std::ostream &os) {
      if (node) {
          os << node->datum << " ";
          traverse_preorder_impl(node->left, os);
          traverse_preorder_impl(node->right, os);
      }
  }

  // Returns the node holding the smallest datum strictly greater than `val`,
  // or nullptr if no such node exists in the subtree.
  static Node * min_greater_than_impl(Node *node, const T &val, Compare less) {
      if (!node)    {
          return nullptr;
      }
      if (!less(val, node->datum)) {
          // node->datum <= val: every candidate lies in the right subtree.
          return min_greater_than_impl(node->right, val, less);
      }
      // node->datum > val: a smaller candidate may exist on the left.
      Node *left_check = min_greater_than_impl(node->left, val, less);
      return left_check ? left_check : node;
  }


};

#include "TreePrint.h"

// Prints `tree` as "[ e1 e2 ... ]": elements in sorted order, each followed
// by a single space, wrapped in square brackets. Returns `os` for chaining.
template <typename T, typename Compare>
std::ostream &operator<<(std::ostream &os,
                         const BinarySearchTree<T, Compare> &tree) {
  os << "[ ";
  const auto stop = tree.end();
  for (auto it = tree.begin(); it != stop; ++it) {
    os << *it << " ";
  }
  os << "]";
  return os;
}

#endif