//! AVL tree implementation (Self-Balancing-Binary-Trees, `src/avl_tree.rs`).
use std::cell::RefCell;
use std::rc::Rc;
use std::fmt::Debug;
use crate::binary_tree_ops::{BinaryTreeOps, NodeOps, TreeNode};

/// A self-balancing binary search tree (AVL tree).
///
/// After every insertion or deletion, each node on the affected path is
/// rebalanced with single or double rotations. This keeps the heights of any
/// node's two subtrees within one of each other.
///
/// NOTE(review): the insert/delete/rotation code stores strong `Rc` clones of
/// a parent into each child's `parent` field, while the parent holds the child
/// in `left`/`right`. Every parent/child pair is therefore an `Rc` reference
/// cycle, and nodes are never freed. Changing `TreeNode::parent` to a `Weak`
/// link (in `binary_tree_ops`) would remove the leak.
pub struct AVLTreeImpl<T>
where
    T: Ord + Clone + Debug,
{
    /// Root node, or `None` for an empty tree.
    root: Option<Rc<RefCell<TreeNode<T>>>>,
}

impl<T: Ord + Clone + Debug> AVLTreeImpl<T> {
    /// Creates an empty tree.
    pub fn new() -> Self {
        Self { root: None }
    }

    /// Inserts `key` and rebalances every node on the insertion path.
    ///
    /// If `key` is already present, the tree is left unchanged.
    pub fn insert(&mut self, key: T) {
        let old_root = self.root.take();
        self.root = self.insert_recursive(old_root, key);
        // A rotation at the top can promote a former child; it must not keep
        // pointing at its old parent.
        if let Some(r) = &self.root {
            r.borrow_mut().parent = None;
        }
    }

    /// Inserts `key` into the subtree rooted at `node`.
    ///
    /// Returns the subtree's new root after rebalancing.
    fn insert_recursive(
        &self,
        node: Option<Rc<RefCell<TreeNode<T>>>>,
        key: T,
    ) -> Option<Rc<RefCell<TreeNode<T>>>> {
        use std::cmp::Ordering;

        let n = match node {
            Some(n) => n,
            None => {
                let leaf = TreeNode::new(key);
                leaf.borrow_mut().set_height(Some(1));
                return Some(leaf);
            }
        };

        // Compare in its own statement so the borrow of `n` ends before the
        // recursive call below (no `Ref` is held across the recursion).
        let ordering = key.cmp(&n.borrow().key);
        match ordering {
            Ordering::Less => {
                let left = n.borrow().left.clone();
                let new_left = self.insert_recursive(left, key);
                Self::attach_left(&n, new_left);
            }
            Ordering::Greater => {
                let right = n.borrow().right.clone();
                let new_right = self.insert_recursive(right, key);
                Self::attach_right(&n, new_right);
            }
            // Duplicate key: nothing to insert.
            Ordering::Equal => {}
        }
        Some(self.balance(n))
    }

    /// Removes `key` if it is present, rebalancing along the search path.
    ///
    /// Deleting an absent key leaves the tree unchanged.
    pub fn delete(&mut self, key: T) {
        let old_root = self.root.take();
        self.root = self.delete_recursive(old_root, key);
        // The new root may be a spliced-in child that still points at the
        // removed node.
        if let Some(r) = &self.root {
            r.borrow_mut().parent = None;
        }
    }

    /// Removes `key` from the subtree rooted at `node`.
    ///
    /// Returns the subtree's new root. A node with two children takes its
    /// in-order successor's key; the successor is then removed from the
    /// right subtree.
    fn delete_recursive(
        &self,
        node: Option<Rc<RefCell<TreeNode<T>>>>,
        key: T,
    ) -> Option<Rc<RefCell<TreeNode<T>>>> {
        use std::cmp::Ordering;

        let n = node?;
        let ordering = key.cmp(&n.borrow().key);
        match ordering {
            Ordering::Less => {
                let left = n.borrow().left.clone();
                let new_left = self.delete_recursive(left, key);
                Self::attach_left(&n, new_left);
            }
            Ordering::Greater => {
                let right = n.borrow().right.clone();
                let new_right = self.delete_recursive(right, key);
                Self::attach_right(&n, new_right);
            }
            Ordering::Equal => {
                let (left, right) = {
                    let nb = n.borrow();
                    (nb.left.clone(), nb.right.clone())
                };
                match (left, right) {
                    // Zero or one child: the child (or `None`) takes n's place.
                    // The caller re-links its parent pointer.
                    (None, child) | (child, None) => return child,
                    (Some(_), Some(right)) => {
                        let successor = self.find_min(Rc::clone(&right));
                        let successor_key = successor.borrow().key.clone();
                        n.borrow_mut().key = successor_key.clone();
                        let new_right = self.delete_recursive(Some(right), successor_key);
                        Self::attach_right(&n, new_right);
                    }
                }
            }
        }
        Some(self.rebalance(n))
    }

    /// Sets `parent.left = child` and points `child.parent` back at `parent`.
    fn attach_left(parent: &Rc<RefCell<TreeNode<T>>>, child: Option<Rc<RefCell<TreeNode<T>>>>) {
        if let Some(c) = &child {
            c.borrow_mut().parent = Some(Rc::clone(parent));
        }
        parent.borrow_mut().left = child;
    }

    /// Sets `parent.right = child` and points `child.parent` back at `parent`.
    fn attach_right(parent: &Rc<RefCell<TreeNode<T>>>, child: Option<Rc<RefCell<TreeNode<T>>>>) {
        if let Some(c) = &child {
            c.borrow_mut().parent = Some(Rc::clone(parent));
        }
        parent.borrow_mut().right = child;
    }

    /// Rotates the subtree rooted at `y` to the right and returns the new root
    /// (y's former left child).
    ///
    /// # Panics
    /// Panics if `y` has no left child.
    fn rotate_right(&self, y: Rc<RefCell<TreeNode<T>>>) -> Rc<RefCell<TreeNode<T>>> {
        let x = y
            .borrow_mut()
            .left
            .take()
            .expect("rotate_right requires a left child");
        // Keys between x and y move from x's right to y's left.
        let middle = x.borrow_mut().right.take();
        Self::attach_left(&y, middle);
        // x takes y's place under y's former parent.
        let grandparent = y.borrow().parent.clone();
        x.borrow_mut().parent = grandparent;
        Self::attach_right(&x, Some(Rc::clone(&y)));
        // Heights bottom-up, and only once both nodes have their final children.
        self.update_height(&y);
        self.update_height(&x);
        x
    }

    /// Rotates the subtree rooted at `x` to the left and returns the new root
    /// (x's former right child).
    ///
    /// # Panics
    /// Panics if `x` has no right child.
    fn rotate_left(&self, x: Rc<RefCell<TreeNode<T>>>) -> Rc<RefCell<TreeNode<T>>> {
        let y = x
            .borrow_mut()
            .right
            .take()
            .expect("rotate_left requires a right child");
        // Keys between x and y move from y's left to x's right.
        let middle = y.borrow_mut().left.take();
        Self::attach_right(&x, middle);
        // y takes x's place under x's former parent.
        let grandparent = x.borrow().parent.clone();
        y.borrow_mut().parent = grandparent;
        Self::attach_left(&y, Some(Rc::clone(&x)));
        // Heights bottom-up, and only once both nodes have their final children.
        self.update_height(&x);
        self.update_height(&y);
        y
    }

    /// Returns the stored height of `node`.
    ///
    /// An empty link counts as 0; a node without a stored height counts as 1.
    fn compute_height(node: &Option<Rc<RefCell<TreeNode<T>>>>) -> i32 {
        node.as_ref().map_or(0, |n| n.borrow().height().unwrap_or(1))
    }

    /// Recomputes `node`'s height from its children's stored heights.
    fn update_height(&self, node: &Rc<RefCell<TreeNode<T>>>) {
        let (left_height, right_height) = {
            let nb = node.borrow();
            (Self::compute_height(&nb.left), Self::compute_height(&nb.right))
        };
        node.borrow_mut()
            .set_height(Some(1 + std::cmp::max(left_height, right_height)));
    }

    /// Left height minus right height (0 for an empty link).
    fn balance_factor(node: &Option<Rc<RefCell<TreeNode<T>>>>) -> i32 {
        node.as_ref().map_or(0, |n| {
            let nb = n.borrow();
            Self::compute_height(&nb.left) - Self::compute_height(&nb.right)
        })
    }

    /// Former balancing routine, kept so existing call sites still compile.
    ///
    /// Now that the rotations maintain heights themselves, it is identical to
    /// [`Self::balance`].
    #[allow(dead_code)] // not called in this impl; retained for compatibility
    fn balance_old(&self, node: Rc<RefCell<TreeNode<T>>>) -> Rc<RefCell<TreeNode<T>>> {
        self.balance(node)
    }

    /// Refreshes `node`'s height and restores the AVL invariant at `node`.
    ///
    /// Returns the root of the (possibly rotated) subtree. The returned root's
    /// `parent` is `node`'s former parent; the caller re-links it.
    fn balance(&self, node: Rc<RefCell<TreeNode<T>>>) -> Rc<RefCell<TreeNode<T>>> {
        self.update_height(&node);
        let bf = Self::balance_factor(&Some(Rc::clone(&node)));

        if bf > 1 {
            // Left-heavy. In the Left-Right case, first straighten the left
            // child so that a single right rotation restores balance.
            if Self::balance_factor(&node.borrow().left) < 0 {
                let left = node
                    .borrow()
                    .left
                    .clone()
                    .expect("a left-heavy node has a left child");
                let rotated = self.rotate_left(left);
                Self::attach_left(&node, Some(rotated));
            }
            self.rotate_right(node)
        } else if bf < -1 {
            // Right-heavy. Mirror image: handle the Right-Left case first.
            if Self::balance_factor(&node.borrow().right) > 0 {
                let right = node
                    .borrow()
                    .right
                    .clone()
                    .expect("a right-heavy node has a right child");
                let rotated = self.rotate_right(right);
                Self::attach_right(&node, Some(rotated));
            }
            self.rotate_left(node)
        } else {
            // Already balanced; the height was refreshed above.
            node
        }
    }

    /// Alias of [`Self::balance`], used on the deletion path.
    fn rebalance(&self, n: Rc<RefCell<TreeNode<T>>>) -> Rc<RefCell<TreeNode<T>>> {
        self.balance(n)
    }

    /// Returns the leftmost (minimum-key) node of the subtree at `start_node`.
    fn find_min(&self, start_node: Rc<RefCell<TreeNode<T>>>) -> Rc<RefCell<TreeNode<T>>> {
        let mut current = start_node;
        loop {
            // Clone the link in its own statement so the borrow ends before
            // `current` is reassigned.
            let next = current.borrow().left.clone();
            match next {
                Some(left) => current = left,
                None => return current,
            }
        }
    }
}

impl<T> BinaryTreeOps<T> for AVLTreeImpl<T>
where
    T: Ord + Clone + Debug,
{
    /// Returns a new shared handle to the root node (no deep copy).
    fn root_node(&self) -> Option<Rc<RefCell<TreeNode<T>>>> {
        self.root.as_ref().map(Rc::clone)
    }

    /// Replaces the root and detaches the new root from any former parent.
    fn set_root_node(&mut self, root: Option<Rc<RefCell<TreeNode<T>>>>) {
        self.root = root;
        if let Some(new_root) = self.root.as_ref() {
            new_root.borrow_mut().parent = None;
        }
    }
}