// Piazza-Classifier / BinarySearchTree.h
//
// BinarySearchTree.h: header-only binary search tree class template.
#ifndef BINARY_SEARCH_TREE_H
#define BINARY_SEARCH_TREE_H
#include <algorithm> //max
#include <cassert>  //assert
#include <cstddef>  //size_t
#include <functional> //less
#include <iostream> //ostream
#include <string>   //string
using namespace std;




// A binary search tree of unique elements of type T.
//
// Ordering is decided solely by Compare (std::less<T> unless another functor
// is supplied): two elements are treated as equal when neither compares less
// than the other. The tree is not self-balancing, so its height depends on
// the insertion order.
//
// Requirements visible in this class: T must be copy-constructible (nodes
// hold a copy of each inserted item) and Compare must be default- and
// copy-constructible and callable as less(a, b) on two T values.
template <typename T,
          typename Compare=std::less<T> // default if argument isn't provided
         >
class BinarySearchTree {
private:

  // A single tree node. Nodes are allocated with new by insert_impl and
  // copy_nodes_impl and released by destroy_nodes_impl.
  struct Node {

    // Builds a childless node holding a value-initialized datum.
    Node() : datum(), left(nullptr), right(nullptr) {}

    // Builds a node holding a copy of datum_in with the given children.
    Node(const T &datum_in, Node *left_in, Node *right_in)
            : datum(datum_in), left(left_in), right(right_in) { }

    T datum;
    Node *left;
    Node *right;
  };

public:

  // Creates an empty tree.
  BinarySearchTree()
    : root(nullptr), less() { }

  // Copy constructor: deep-copies every node of other.
  // The comparator is not copied: no constructor accepts one, so every tree
  // of this type uses a default-constructed Compare.
  BinarySearchTree(const BinarySearchTree &other)
    : root(copy_nodes_impl(other.root)), less() { }

  // Assignment operator: replaces this tree's elements with a deep copy of
  // rhs. The copy is built before the old nodes are released, so if copying
  // throws, this tree is left exactly as it was.
  BinarySearchTree &operator=(const BinarySearchTree &rhs) {
    if (this != &rhs) {
      Node *replacement = copy_nodes_impl(rhs.root);
      destroy_nodes_impl(root);
      root = replacement;
    }
    return *this;
  }

  // Releases every node owned by the tree.
  ~BinarySearchTree() {
    destroy_nodes_impl(root);
  }

  // Returns true when the tree holds no elements.
  bool empty() const {
    return empty_impl(root);
  }

  // Returns the number of nodes on the longest root-to-leaf path
  // (0 for an empty tree).
  std::size_t height() const {
    return static_cast<std::size_t>(height_impl(root));
  }

  // Returns the number of elements in the tree.
  std::size_t size() const {
    return static_cast<std::size_t>(size_impl(root));
  }

  // Writes the elements to os in ascending order, each followed by a space.
  void traverse_inorder(std::ostream &os) const {
    traverse_inorder_impl(root, os);
  }

  // Writes the elements to os in preorder (node, left subtree, right
  // subtree), each followed by a space.
  void traverse_preorder(std::ostream &os) const {
    traverse_preorder_impl(root, os);
  }

  // Returns true when every node is strictly greater than everything in its
  // left subtree and strictly less than everything in its right subtree.
  // Useful after elements have been modified through an iterator.
  bool check_sorting_invariant() const {
    return check_sorting_invariant_impl(root, less);
  }

  // Forward iterator that visits the elements in ascending order.
  // A default-constructed Iterator is the past-the-end value; iterators
  // compare equal when they refer to the same node.
  class Iterator {

  public:
    Iterator()
      : root(nullptr), current_node(nullptr), less() {}

    // Returns the element referred to. Must not be called on end().
    // Modifying the element through this reference can break the sorting
    // invariant of the tree.
    T &operator*() const {
      assert(current_node);
      return current_node->datum;
    }

    // Member access to the element referred to. Must not be called on end().
    T *operator->() const {
      assert(current_node);
      return &current_node->datum;
    }

    // Advances to the next larger element, or to end() when there is none.
    // Must not be called on end().
    Iterator &operator++() {
      assert(current_node);
      if (current_node->right) {
        // If has right child, next element is minimum of right subtree
        current_node = min_element_impl(current_node->right);
      }
      else {
        // Otherwise, look in the whole tree for the next biggest element
        current_node = min_greater_than_impl(root, current_node->datum, less);
      }
      return *this;
    }

    // Postfix increment: advances and returns the previous position.
    Iterator operator++(int) {
      Iterator result(*this);
      ++(*this);
      return result;
    }

    bool operator==(const Iterator &rhs) const {
      return current_node == rhs.current_node;
    }

    bool operator!=(const Iterator &rhs) const {
      return current_node != rhs.current_node;
    }

  private:
    friend class BinarySearchTree;

    Node *root;          // root of the tree being traversed
    Node *current_node;  // node referred to; nullptr means past-the-end
    Compare less;        // ordering used to locate the successor

    Iterator(Node *root_in, Node* current_node_in, Compare less_in)
      : root(root_in), current_node(current_node_in), less(less_in) { }

  };

  // Returns an iterator to the smallest element, or end() for an empty tree.
  Iterator begin() const {
    if (root == nullptr) {
      return Iterator();
    }
    return Iterator(root, min_element_impl(root), less);
  }

  // Returns the past-the-end iterator.
  Iterator end() const {
    return Iterator();
  }

  // Returns an iterator to the smallest element, or an iterator equal to
  // end() when the tree is empty.
  Iterator min_element() const {
    return Iterator(root, min_element_impl(root), less);
  }

  // Returns an iterator to the largest element, or an iterator equal to
  // end() when the tree is empty.
  Iterator max_element() const {
    return Iterator(root, max_element_impl(root), less);
  }

  // Returns an iterator to the smallest element strictly greater than value,
  // or an iterator equal to end() when no such element exists.
  Iterator min_greater_than(const T &value) const {
    return Iterator(root, min_greater_than_impl(root, value, less), less);
  }

  // Returns an iterator to the element equal to query, or an iterator equal
  // to end() when the tree holds no such element.
  Iterator find(const T &query) const {
    return Iterator(root, find_impl(root, query, less), less);
  }

  // Inserts a copy of item and returns an iterator to it.
  // Precondition (checked with assert): item is not already in the tree.
  Iterator insert(const T &item) {
    assert(find(item) == end());
    root = insert_impl(root, item, less);
    return find(item);
  }

  // Declared here, defined outside the class body (presumably in
  // TreePrint.h, which the file includes after this class).
  std::string to_string() const;


private:

  // Root node of the tree; nullptr when the tree is empty.
  Node *root;

  // Comparator instance handed to every *_impl helper.
  Compare less;

  // Helper types and function declared for the definitions that live
  // outside the class body.
  class Tree_grid_square;
  class Tree_grid;

  int get_max_elt_width() const;

  // Returns true when node denotes an empty (sub)tree.
  static bool empty_impl(const Node *node) {
    return node == nullptr;
  }

  // Returns the number of nodes in the subtree rooted at node.
  static int size_impl(const Node *node) {
    if (empty_impl(node)) {
      return 0;
    }
    return size_impl(node->left) + 1 + size_impl(node->right);
  }

  // Returns the number of nodes on the longest path from node down to a
  // leaf; 0 for an empty subtree.
  static int height_impl(const Node *node) {
    if (empty_impl(node)) {
      return 0;
    }
    return std::max(height_impl(node->left), height_impl(node->right)) + 1;
  }

  // Returns a newly allocated deep copy of the subtree rooted at node
  // (nullptr for an empty subtree). Each datum is copy-constructed, so T
  // needs neither a default constructor nor an assignment operator. If any
  // allocation or copy throws, the nodes built so far are released before
  // the exception propagates.
  static Node *copy_nodes_impl(Node *node) {
    if (empty_impl(node)) {
      return nullptr;
    }
    Node *duplicate = new Node(node->datum, nullptr, nullptr);
    try {
      duplicate->left = copy_nodes_impl(node->left);
      duplicate->right = copy_nodes_impl(node->right);
    }
    catch (...) {
      // Children not yet attached are still nullptr, so this frees exactly
      // what has been built below duplicate.
      destroy_nodes_impl(duplicate);
      throw;
    }
    return duplicate;
  }

  // Deletes every node in the subtree rooted at node.
  static void destroy_nodes_impl(Node *node) {
    if (empty_impl(node)) {
      return;
    }
    destroy_nodes_impl(node->left);
    destroy_nodes_impl(node->right);
    delete node;
  }

  // Returns the node whose datum is equivalent to query under less, or
  // nullptr when the subtree rooted at node holds no such element.
  static Node * find_impl(Node *node, const T &query, Compare less) {
    while (!empty_impl(node)) {
      if (less(query, node->datum)) {
        node = node->left;
      }
      else if (less(node->datum, query)) {
        node = node->right;
      }
      else {
        // Neither is less than the other: this is the match.
        return node;
      }
    }
    return nullptr;
  }

  // Inserts item into the subtree rooted at node and returns the root of the
  // resulting subtree, which the caller must store back into its link. An
  // item equivalent to an existing datum leaves the subtree unchanged.
  static Node * insert_impl(Node *node, const T &item, Compare less) {
    if (empty_impl(node)) {
      return new Node(item, nullptr, nullptr);
    }
    if (less(node->datum, item)) {
      node->right = insert_impl(node->right, item, less);
    }
    else if (less(item, node->datum)) {
      node->left = insert_impl(node->left, item, less);
    }
    return node;
  }

  // Returns the leftmost node of the subtree rooted at node (its smallest
  // element when the sorting invariant holds), or nullptr if it is empty.
  static Node * min_element_impl(Node *node) {
    if (empty_impl(node)) {
      return nullptr;
    }
    while (!empty_impl(node->left)) {
      node = node->left;
    }
    return node;
  }

  // Returns the rightmost node of the subtree rooted at node (its largest
  // element when the sorting invariant holds), or nullptr if it is empty.
  static Node * max_element_impl(Node *node) {
    if (empty_impl(node)) {
      return nullptr;
    }
    while (!empty_impl(node->right)) {
      node = node->right;
    }
    return node;
  }

  // Returns true when the subtree rooted at node satisfies the sorting
  // invariant: every datum is strictly greater than all data in its left
  // subtree and strictly less than all data in its right subtree.
  static bool check_sorting_invariant_impl(const Node *node, Compare less) {
    return sorted_within_bounds(node, nullptr, nullptr, less);
  }

  // Returns true when every datum below node lies strictly between *lower
  // and *upper (a null bound means unbounded on that side) and each subtree
  // is itself sorted. Each node is visited once and no datum is copied.
  static bool sorted_within_bounds(const Node *node, const T *lower,
                                   const T *upper, Compare &less) {
    if (empty_impl(node)) {
      return true;
    }
    if (lower != nullptr && !less(*lower, node->datum)) {
      return false;
    }
    if (upper != nullptr && !less(node->datum, *upper)) {
      return false;
    }
    // The current datum tightens the upper bound on the left and the lower
    // bound on the right.
    return sorted_within_bounds(node->left, lower, &node->datum, less)
        && sorted_within_bounds(node->right, &node->datum, upper, less);
  }

  // Writes the subtree's data to os in ascending order, each followed by a
  // single space.
  static void traverse_inorder_impl(const Node *node, std::ostream &os) {
    if (empty_impl(node)) {
      return;
    }
    traverse_inorder_impl(node->left, os);
    os << node->datum << " ";
    traverse_inorder_impl(node->right, os);
  }

  // Writes the subtree's data to os in preorder, each followed by a single
  // space.
  static void traverse_preorder_impl(const Node *node, std::ostream &os) {
    if (empty_impl(node)) {
      return;
    }
    os << node->datum << " ";
    traverse_preorder_impl(node->left, os);
    traverse_preorder_impl(node->right, os);
  }

  // Returns true when node has no left child. node must not be null.
  static bool leftnull(const Node *node)
  {
    return node->left == nullptr;
  }

  // Returns true when node has no right child. node must not be null.
  static bool rightnull(const Node *node)
  {
    return node->right == nullptr;
  }

  // Returns the datum stored in node, by reference so nothing is copied.
  // node must not be null.
  static const T &datumreturn(const Node *node)
  {
    return node->datum;
  }

  // Returns true when a and b are equivalent under less, i.e. neither is
  // less than the other.
  static bool equal(const T &a, const T &b, Compare less)
  {
    return !less(a, b) && !less(b, a);
  }

  // Returns the node holding the smallest datum strictly greater than val in
  // the subtree rooted at node, or nullptr when there is none. Follows a
  // single root-to-leaf path.
  static Node * min_greater_than_impl(Node *node, const T &val, Compare less) {
    Node *best = nullptr;
    while (!empty_impl(node)) {
      if (less(val, node->datum)) {
        // node is greater than val; anything better must be smaller than
        // node and therefore in its left subtree.
        best = node;
        node = node->left;
      }
      else {
        // node is less than or equal to val, so only its right subtree can
        // hold a greater element.
        node = node->right;
      }
    }
    return best;
  }

};
#include "TreePrint.h" 



// Stream insertion for a tree: writes "[ ", then every element in ascending
// (iteration) order followed by a single space, then "]". An empty tree
// therefore prints as "[ ]". Returns os so calls can be chained.
template <typename T, typename Compare>
std::ostream &operator<<(std::ostream &os,
                         const BinarySearchTree<T, Compare> &tree) {
  os << "[ ";
  const auto stop = tree.end();
  for (auto position = tree.begin(); position != stop; ++position) {
    os << *position << " ";
  }
  os << "]";
  return os;
}

#endif