#ifndef BINARY_SEARCH_TREE_H
#define BINARY_SEARCH_TREE_H
#include <cassert> //assert
#include <cstddef> //size_t
#include <functional> //less
#include <iostream> //ostream
#include <string> //string
template <typename T,
          typename Compare=std::less<T> // default if argument isn't provided
         >
class BinarySearchTree {
  // A tree node. Nodes are heap-allocated with `new` by this class and
  // released by destroy_nodes_impl().
  struct Node {
    Node() : left(nullptr), right(nullptr) {}
    Node(const T &datum_in, Node *left_in, Node *right_in)
      : datum(datum_in), left(left_in), right(right_in) { }
    T datum;
    Node *left;
    Node *right;
  };

public:
  // Constructs an empty tree.
  BinarySearchTree()
    : root(nullptr) { }

  // Deep copy: duplicates every node of `other` and also copies its
  // comparator (it used to be default-constructed, dropping comparator state).
  BinarySearchTree(const BinarySearchTree &other)
    : root(copy_nodes_impl(other.root)), less(other.less) { }

  // Deep-copy assignment. The new nodes are built before the old ones are
  // released, so an exception thrown while copying leaves *this intact.
  BinarySearchTree &operator=(const BinarySearchTree &rhs) {
    if (this == &rhs) {
      return *this;
    }
    Node *new_root = copy_nodes_impl(rhs.root);
    destroy_nodes_impl(root);
    root = new_root;
    less = rhs.less;
    return *this;
  }

  // Releases every node owned by the tree.
  ~BinarySearchTree() {
    destroy_nodes_impl(root);
  }

  // Returns true if the tree holds no elements.
  bool empty() const {
    return empty_impl(root);
  }

  // Returns the number of nodes on the longest root-to-leaf path
  // (0 for an empty tree, 1 for a single node).
  size_t height() const {
    return static_cast<size_t>(height_impl(root));
  }

  // Returns the number of elements in the tree.
  size_t size() const {
    return static_cast<size_t>(size_impl(root));
  }

  // Writes the elements in in-order (sorted) order, each followed by a space.
  void traverse_inorder(std::ostream &os) const {
    traverse_inorder_impl(root, os);
  }

  // Writes the elements in preorder (node, left subtree, right subtree),
  // each followed by a space.
  void traverse_preorder(std::ostream &os) const {
    traverse_preorder_impl(root, os);
  }

  // Returns true if, for every node, every element of its left subtree
  // compares strictly less than it and every element of its right subtree
  // compares strictly greater (according to Compare).
  bool check_sorting_invariant() const {
    return check_sorting_invariant_impl(root, less);
  }

  // Iterator that visits elements in sorted (in-order) order.
  // An Iterator whose current node is null represents end().
  class Iterator {
  public:
    Iterator()
      : root(nullptr), current_node(nullptr) {}

    // Requires: *this is not end().
    T &operator*() const {
      return current_node->datum;
    }

    // Requires: *this is not end().
    T *operator->() const {
      return &current_node->datum;
    }

    // Advances to the in-order successor; becomes end() after the largest
    // element. Requires: *this is not end().
    Iterator &operator++() {
      if (current_node->right) {
        // The successor is the leftmost node of the right subtree.
        current_node = min_element_impl(current_node->right);
      } else {
        // No right subtree: search from the root for the smallest element
        // strictly greater than the current one (null if none).
        current_node = min_greater_than_impl(root, current_node->datum, less);
      }
      return *this;
    }

    // Post-increment: advances *this and returns its previous value.
    Iterator operator++(int) {
      Iterator result(*this);
      ++(*this);
      return result;
    }

    // Iterators compare equal when they refer to the same node
    // (two end() iterators are therefore equal).
    bool operator==(const Iterator &rhs) const {
      return current_node == rhs.current_node;
    }

    bool operator!=(const Iterator &rhs) const {
      return current_node != rhs.current_node;
    }

  private:
    friend class BinarySearchTree;

    Node *root;          // root of the tree being iterated (for successor search)
    Node *current_node;  // node referred to; null means end()
    Compare less;

    Iterator(Node *root_in, Node* current_node_in, Compare less_in)
      : root(root_in), current_node(current_node_in), less(less_in) { }
  };

  // Iterator to the smallest element, or end() if the tree is empty.
  Iterator begin() const {
    if (root == nullptr) {
      return Iterator();
    }
    return Iterator(root, min_element_impl(root), less);
  }

  // Past-the-end iterator.
  Iterator end() const {
    return Iterator();
  }

  // Iterator to the smallest element, or end() if the tree is empty.
  Iterator min_element() const {
    return Iterator(root, min_element_impl(root), less);
  }

  // Iterator to the largest element, or end() if the tree is empty.
  Iterator max_element() const {
    return Iterator(root, max_element_impl(root), less);
  }

  // Iterator to the smallest element strictly greater than `value`,
  // or end() if there is none.
  Iterator min_greater_than(const T &value) const {
    return Iterator(root, min_greater_than_impl(root, value, less), less);
  }

  // Iterator to the element equivalent to `query` (neither compares less
  // than the other), or end() if no such element exists.
  Iterator find(const T &query) const {
    return Iterator(root, find_impl(root, query, less), less);
  }

  // Inserts `item` and returns an iterator to it.
  // Requires: no equivalent element is already present (checked by assert).
  Iterator insert(const T &item) {
    assert(find(item) == end());
    root = insert_impl(root, item, less);
    return find(item);
  }

  // Renders the tree as text. Defined outside this class
  // (presumably in TreePrint.h -- TODO confirm).
  std::string to_string() const;

private:
  Node *root;
  Compare less;

  // Printing helpers declared here and defined outside this class
  // (presumably in TreePrint.h -- TODO confirm).
  class Tree_grid_square;
  class Tree_grid;
  int get_max_elt_width() const;

  // True if the subtree rooted at `node` has no nodes.
  static bool empty_impl(const Node *node) {
    return (!node);
  }

  // Number of nodes in the subtree rooted at `node`.
  static int size_impl(const Node *node) {
    if (!node) {
      return 0;
    }
    return 1 + size_impl(node->left) + size_impl(node->right);
  }

  // Height of the subtree rooted at `node` (0 when empty).
  static int height_impl(const Node *node) {
    if (!node) {
      return 0;
    }
    // Evaluate each subtree exactly once; recursing twice per side made the
    // original exponential in the height of the tree.
    const int left_height = height_impl(node->left);
    const int right_height = height_impl(node->right);
    return 1 + (left_height > right_height ? left_height : right_height);
  }

  // Returns a newly allocated deep copy of the subtree rooted at `node`.
  static Node *copy_nodes_impl(Node *node) {
    if (!node) {
      return nullptr;
    }
    return new Node(node->datum, copy_nodes_impl(node->left),
                    copy_nodes_impl(node->right));
  }

  // Deletes every node in the subtree rooted at `node` (children first).
  static void destroy_nodes_impl(Node *node) {
    if (node) {
      destroy_nodes_impl(node->left);
      destroy_nodes_impl(node->right);
      delete node;
    }
  }

  // Node whose datum is equivalent to `query`, or nullptr if absent.
  static Node * find_impl(Node *node, const T &query, Compare less) {
    if (!node) {
      return nullptr;
    }
    if (less(query, node->datum)) {
      return find_impl(node->left, query, less);
    } else if (less(node->datum, query)) {
      return find_impl(node->right, query, less);
    } else {
      return node;
    }
  }

  // Inserts `item` into the subtree rooted at `node` and returns the
  // (possibly new) subtree root. Items that do not compare less than a
  // node's datum go to its right.
  static Node * insert_impl(Node *node, const T &item, Compare less) {
    if (!node) {
      return new Node(item, nullptr, nullptr);
    }
    if (less(item, node->datum)) {
      node->left = insert_impl(node->left, item, less);
    } else {
      node->right = insert_impl(node->right, item, less);
    }
    return node;
  }

  // Leftmost node of the subtree rooted at `node`, or nullptr if empty.
  static Node * min_element_impl(Node *node) {
    if (!node) {
      return nullptr;
    }
    if (!node->left) {
      return node;
    }
    return min_element_impl(node->left);
  }

  // Rightmost node of the subtree rooted at `node`, or nullptr if empty.
  static Node * max_element_impl(Node *node) {
    if (!node) {
      return nullptr;
    }
    if (!node->right) {
      return node;
    }
    return max_element_impl(node->right);
  }

  // True if the subtree rooted at `node` satisfies the sorting invariant.
  // Nodes with zero, one, or two children are all handled: a missing
  // child imposes no constraint.
  static bool check_sorting_invariant_impl(const Node *node, Compare less) {
    if (!node) {
      return true;
    }
    // The rightmost node of the left subtree is its maximum only if that
    // subtree is itself sorted; the recursive calls below verify that, so
    // any violation is still detected.
    if (node->left) {
      const Node *left_max = max_element_impl(node->left);
      if (!less(left_max->datum, node->datum)) {
        return false;
      }
    }
    if (node->right) {
      const Node *right_min = min_element_impl(node->right);
      if (!less(node->datum, right_min->datum)) {
        return false;
      }
    }
    return check_sorting_invariant_impl(node->left, less)
        && check_sorting_invariant_impl(node->right, less);
  }

  // Writes the subtree in in-order order, each datum followed by a space.
  static void traverse_inorder_impl(const Node *node, std::ostream &os) {
    if (node) {
      traverse_inorder_impl(node->left, os);
      os << node->datum << " ";
      traverse_inorder_impl(node->right, os);
    }
  }

  // Writes the subtree in preorder, each datum followed by a space.
  // Children must recurse with preorder too; the original called the
  // in-order helper here, which produced a mixed ordering.
  static void traverse_preorder_impl(const Node *node, std::ostream &os) {
    if (node) {
      os << node->datum << " ";
      traverse_preorder_impl(node->left, os);
      traverse_preorder_impl(node->right, os);
    }
  }

  // Node holding the smallest datum strictly greater than `val` within the
  // subtree rooted at `node`, or nullptr if none exists.
  static Node * min_greater_than_impl(Node *node, const T &val, Compare less) {
    if (!node) {
      return nullptr;
    }
    if (!less(val, node->datum)) {
      // node->datum is not greater than val: candidates lie only to the right.
      return min_greater_than_impl(node->right, val, less);
    }
    // node qualifies, but a smaller qualifying datum may be in its left subtree.
    Node *left_check = min_greater_than_impl(node->left, val, less);
    return left_check ? left_check : node;
  }
};
#include "TreePrint.h"
template <typename T, typename Compare>
std::ostream &operator<<(std::ostream &os,
const BinarySearchTree<T, Compare> &tree) {
os << "[ ";
for (T& elt : tree) {
os << elt << " ";
}
return os << "]";
}
#endif