// Self-Balancing-Binary-Trees/src/unit_tests.rs
// Unit tests for the red-black tree and AVL tree implementations.
use crate::binary_tree_ops::{BinaryTreeOps, NodeOps, TreeNode, NodeColor};
// use crate::red_black_tree;
// use crate::avl_tree;
use crate::red_black_tree::RedBlackTreeImpl;
use crate::avl_tree::AVLTreeImpl;
// use crate::binary_tree_ops::BinaryTreeOps;
// use crate::binary_tree_ops::NodeOps;
// use crate::binary_tree_ops::TreeNode;
// use crate::binary_tree_ops::NodeColor;
use std::fmt::Debug;

/// Asserts that `tree` satisfies the red-black invariants:
///
/// 1. the root (if any) is black;
/// 2. no red node has a red child;
/// 3. for every node, all paths down to a NIL (`None`) child cross the same
///    number of black nodes (NIL itself is not counted).
///
/// # Panics
/// Panics on the first violated invariant.
// Only called from `#[test]` functions, so it is dead code in non-test builds.
#[allow(dead_code)]
fn check_red_black_properties<T: Ord + Clone + Debug>(tree: &RedBlackTreeImpl<T>) {
    // Property 1: the root must be black.
    if let Some(root) = tree.root_node() {
        assert_eq!(root.borrow().color(), Some(NodeColor::Black), "Root must be black in a RB tree");
    }

    // Property 2, checked top-down (node -> children). Walking child links
    // instead of `parent` back-links means a stale parent pointer cannot
    // produce a spurious violation or mask a real one.
    fn check_no_double_red<T: Ord + Clone + Debug>(
        node: &Option<std::rc::Rc<std::cell::RefCell<TreeNode<T>>>>,
    ) {
        if let Some(n) = node {
            let nb = n.borrow();
            if nb.color() == Some(NodeColor::Red) {
                for child in nb.left.iter().chain(nb.right.iter()) {
                    assert_ne!(
                        child.borrow().color(),
                        Some(NodeColor::Red),
                        "No red node should have a red parent in RB tree"
                    );
                }
            }
            check_no_double_red(&nb.left);
            check_no_double_red(&nb.right);
        }
    }

    // Property 3: returns the black height of `node` (number of black nodes on
    // any path from `node` down to a NIL child, NIL excluded), asserting at
    // every node that both subtrees agree. Asserting per node pinpoints the
    // subtree where the invariant first breaks.
    fn black_height<T: Ord + Clone + Debug>(
        node: &Option<std::rc::Rc<std::cell::RefCell<TreeNode<T>>>>,
    ) -> usize {
        match node {
            None => 0,
            Some(n) => {
                let nb = n.borrow();
                let left = black_height(&nb.left);
                let right = black_height(&nb.right);
                assert_eq!(
                    left, right,
                    "All root-to-leaf paths must have the same number of black nodes in a Red-Black Tree"
                );
                let is_black = nb.color() == Some(NodeColor::Black);
                left + usize::from(is_black)
            }
        }
    }

    check_no_double_red(&tree.root_node());

    let height = black_height(&tree.root_node());
    println!("black height: {}", height);
}

/// Asserts that every node of `tree` satisfies the AVL balance condition:
/// the heights of its left and right subtrees differ by at most one.
///
/// Heights are recomputed from the tree's structure (an empty subtree has
/// height 0), so this does not read any height the tree may cache.
///
/// # Panics
/// Panics if any node is out of balance.
// Only called from `#[test]` functions, so it is dead code in non-test builds.
#[allow(dead_code)]
fn check_avl_balance<T: Ord + Clone + Debug>(tree: &AVLTreeImpl<T>) {
    // Post-order walk: returns the height of `node`'s subtree and checks the
    // balance condition on the way back up, visiting each node exactly once
    // (O(n) instead of recomputing subtree heights at every node).
    fn balanced_height<T: Ord + Clone + Debug>(
        node: &Option<std::rc::Rc<std::cell::RefCell<TreeNode<T>>>>,
    ) -> i32 {
        node.as_ref().map_or(0, |n| {
            let nb = n.borrow();
            let left_h = balanced_height(&nb.left);
            let right_h = balanced_height(&nb.right);
            assert!(
                (left_h - right_h).abs() <= 1,
                "AVL tree node is not balanced (left height {}, right height {})",
                left_h,
                right_h
            );
            left_h.max(right_h) + 1
        })
    }

    balanced_height(&tree.root_node());
}

/// Inserts a fixed set of keys into a red-black tree, validating the
/// invariants after every insertion, then checks each key can be found.
#[test]
fn test_red_black_insert_search() {
    let keys = [35, 10, 20, 30, 15, 25, 5];
    let mut rb_tree = RedBlackTreeImpl::new();

    // Re-validate the red-black invariants after each individual insertion.
    for key in keys.iter().copied() {
        rb_tree.insert(key);
        println!("\nRB Tree Visualization:");
        rb_tree.visualize();
        check_red_black_properties(&rb_tree);
    }

    // Every inserted key must be retrievable afterwards.
    for key in &keys {
        assert!(
            rb_tree.find_node(key).is_some(),
            "Value {} should be found in RB tree after insertion",
            key
        );
    }
}

/// Builds a red-black tree, deletes keys in two phases (validating the
/// invariants after each deletion), and finally expects an empty tree.
#[test]
fn test_red_black_delete() {
    /// Removes `value`, prints the tree, then asserts the value is gone and
    /// the red-black invariants still hold.
    fn delete_and_verify(tree: &mut RedBlackTreeImpl<i32>, value: i32) {
        tree.delete(value);
        println!("\nRB Tree Visualization:");
        tree.visualize();
        assert!(
            tree.find_node(&value).is_none(),
            "Value {} should be deleted from RB tree",
            value
        );
        check_red_black_properties(tree);
    }

    let mut rb_tree = RedBlackTreeImpl::new();
    for &value in [35, 10, 20, 30, 15, 25, 5].iter() {
        rb_tree.insert(value);
        println!("\nRB Tree Visualization:");
        rb_tree.visualize();
        check_red_black_properties(&rb_tree);
    }

    // Phase 1: remove a subset of the keys.
    for &value in [20, 5, 35].iter() {
        delete_and_verify(&mut rb_tree, value);
    }

    // Keys not removed in phase 1 must still be present.
    let remaining = [10, 30, 15, 25];
    for value in remaining.iter() {
        assert!(
            rb_tree.find_node(value).is_some(),
            "Value {} should still be in RB tree",
            value
        );
    }

    // Phase 2: remove everything that is left.
    // NOTE(review): the red-black invariant check during this phase is marked
    // as failing; presumably a defect in `RedBlackTreeImpl::delete` — TODO investigate.
    for &value in remaining.iter() {
        delete_and_verify(&mut rb_tree, value);
    }

    assert!(rb_tree.is_empty(), "RB tree should be empty");
}

/// Inserts a fixed set of keys into an AVL tree, checking balance after each
/// insertion, then verifies every key can be found.
#[test]
fn test_avl_insert_search() {
    let keys = [10, 20, 30, 15, 25, 5];
    let mut avl_tree = AVLTreeImpl::new();

    // Balance must hold after every single insertion, not just at the end.
    for &key in keys.iter() {
        avl_tree.insert(key);
        check_avl_balance(&avl_tree);
    }

    for key in &keys {
        assert!(
            avl_tree.find_node(key).is_some(),
            "Value {} should be found in AVL tree after insertion",
            key
        );
    }
}

/// Builds an AVL tree, deletes keys in two phases (checking balance after
/// each deletion), and finally expects an empty tree.
#[test]
fn test_avl_delete() {
    /// Removes `key`, then asserts it is gone and the tree is still balanced.
    fn delete_and_verify(tree: &mut AVLTreeImpl<i32>, key: i32) {
        tree.delete(key);
        assert!(
            tree.find_node(&key).is_none(),
            "Value {} should be deleted from AVL tree",
            key
        );
        check_avl_balance(tree);
    }

    let mut avl_tree = AVLTreeImpl::new();
    for &key in [35, 10, 20, 30, 15, 25, 5].iter() {
        avl_tree.insert(key);
        check_avl_balance(&avl_tree);
    }

    // Phase 1: remove a subset of the keys.
    for &key in [20, 5, 35].iter() {
        delete_and_verify(&mut avl_tree, key);
    }

    // Keys not removed in phase 1 must still be present.
    let remaining = [10, 30, 15, 25];
    for key in remaining.iter() {
        assert!(
            avl_tree.find_node(key).is_some(),
            "Value {} should still be in AVL tree",
            key
        );
    }

    // Phase 2: remove everything that is left.
    for &key in remaining.iter() {
        delete_and_verify(&mut avl_tree, key);
    }

    assert!(avl_tree.is_empty(), "AVL tree should be empty");
}

/// Drives the AVL insertion path through each rotation case (LL, RR, LR, RL)
/// plus a longer stress sequence, visualizing and validating after every insert.
#[test]
fn test_all_avl_insert_scenarios() {
    const SEPARATOR: &str = "\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n";

    // Each entry: (label printed before every visualization, insertion order).
    let scenarios: [(&str, &[i32]); 5] = [
        // LL imbalance
        ("AVL LL Tree Visualization:", &[10, 5, 2]),
        // RR imbalance
        ("AVL RR Tree Visualization:", &[2, 5, 10]),
        // LR imbalance
        ("AVL LR Tree Visualization:", &[10, 5, 8]),
        // RL imbalance
        ("AVL RL Tree Visualization:", &[5, 10, 8]),
        // Stress test
        ("AVL Stress Visualization:", &[30, 20, 40, 10, 25, 5, 50, 60, 55]),
    ];

    for (index, &(label, keys)) in scenarios.iter().enumerate() {
        // Separate consecutive scenarios in the captured output.
        if index > 0 {
            println!("{}", SEPARATOR);
        }

        let mut tree = AVLTreeImpl::new();
        for &key in keys {
            tree.insert(key);
            println!("{}", label);
            tree.visualize();
            assert!(tree.find_node(&key).is_some(), "{} should be in AVL tree", key);
            check_avl_balance(&tree);
        }
    }
}

/// Drives the AVL deletion path through each rotation case (LL, RR, LR, RL)
/// plus a longer stress sequence. Each scenario builds a tree, then deletes
/// keys, visualizing and validating the tree after every operation.
#[test]
fn test_all_avl_delete_scenarios() {
    const SEPARATOR: &str = "\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n";

    // Each entry: (label printed before every visualization,
    //              insertion order, deletion order).
    let scenarios: [(&str, &[i32], &[i32]); 5] = [
        // LL imbalance
        ("AVL LL Tree Visualization:", &[30, 20, 40, 10, 25], &[40]),
        // RR imbalance
        ("AVL RR Tree Visualization:", &[30, 20, 40, 35, 45], &[20]),
        // LR imbalance
        ("AVL LR Tree Visualization:", &[20, 30, 10, 5, 15, 25, 12], &[25]),
        // RL imbalance
        ("AVL RL Tree Visualization:", &[20, 10, 30, 15, 25, 40, 28], &[15]),
        // Stress test
        (
            "AVL Stress Tree Visualization:",
            &[50, 30, 70, 20, 40, 60, 80, 10],
            &[80, 70, 60, 50],
        ),
    ];

    for (index, &(label, inserts, deletes)) in scenarios.iter().enumerate() {
        // Separate consecutive scenarios in the captured output.
        if index > 0 {
            println!("{}", SEPARATOR);
        }

        let mut tree = AVLTreeImpl::new();

        for &key in inserts {
            tree.insert(key);
            println!("{}", label);
            tree.visualize();
            assert!(tree.find_node(&key).is_some(), "{} should be in AVL tree", key);
            check_avl_balance(&tree);
        }

        for &key in deletes {
            tree.delete(key);
            println!("{}", label);
            tree.visualize();
            assert!(tree.find_node(&key).is_none(), "{} should be deleted from AVL tree", key);
            check_avl_balance(&tree);
        }
    }
}