python-data-structures-treaps / py_treaps / treap_map.py
treap_map.py
Raw
import random
import typing
import math
from collections.abc import Iterator
from typing import List, Optional, cast

from py_treaps.treap import KT, VT, Treap
from py_treaps.treap_node import TreapNode


# Example usage found in test_treaps.py
class TreapMap(Treap[KT, VT]):
    """Key/value map implemented as a treap.

    Nodes are ``TreapNode`` objects linked through ``parent``, ``left_child``
    and ``right_child``. Keys follow binary-search-tree order and priorities
    follow max-heap order (a parent's priority is >= its children's). An
    empty map is always represented by ``root_node is None``.

    Example usage found in test_treaps.py.
    """

    def __init__(self) -> None:
        # Start empty. ``keys``/``key_iter_index`` back the legacy ``__next__``
        # protocol and are refreshed by every ``__iter__`` call.
        self.root_node = None
        self.keys = []
        self.key_iter_index = 0

    def get_root_node(self) -> TreapNode:
        """Return the root node, or None when the map is empty."""
        return self.root_node

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------
    def _find_node(self, key: KT) -> Optional[TreapNode]:
        """Return the node holding ``key``, or None if the key is absent."""
        node = self.root_node
        while node is not None:
            if node.key == key:
                return node
            if node.key < key:
                node = node.right_child
            elif node.key > key:
                node = node.left_child
            else:
                # Keys that are neither equal, smaller nor larger (e.g. NaN)
                # cannot be located.
                return None
        return None

    def lookup(self, key: KT) -> Optional[VT]:
        """Return the value stored under ``key``, or None if it is absent."""
        node = self._find_node(key)
        return None if node is None else node.value

    # ------------------------------------------------------------------
    # Heap-order maintenance
    # ------------------------------------------------------------------
    def _sift_up(self, node: TreapNode) -> None:
        """Rotate ``node`` upward while its priority exceeds its parent's."""
        while node.parent is not None and node.priority > node.parent.priority:
            # A right child is lifted by a left rotation of its parent and
            # vice versa.
            direction = "left" if node is node.parent.right_child else "right"
            self.rotate(node.parent, direction)

    @staticmethod
    def _higher_priority_child(node: TreapNode) -> Optional[TreapNode]:
        """Return the child with the larger priority (left wins ties), or None for a leaf."""
        left, right = node.left_child, node.right_child
        if left is None:
            return right
        if right is None:
            return left
        return right if right.priority > left.priority else left

    def _sift_down(self, node: TreapNode) -> None:
        """Rotate ``node`` downward while one of its children has a higher priority."""
        while True:
            child = self._higher_priority_child(node)
            if child is None or child.priority <= node.priority:
                return
            self.rotate(node, "left" if child is node.right_child else "right")

    # ------------------------------------------------------------------
    # Insertion
    # ------------------------------------------------------------------
    def insert(self, key: KT, value: VT, priority: "Optional[int]" = None) -> None:
        """Insert ``key`` -> ``value`` or update the value of an existing key.

        If ``priority`` is given, it becomes the node's priority (for a new or
        an existing key); the node is then rotated up or down as needed to
        restore heap order. Without ``priority`` a new node keeps the
        priority chosen by ``TreapNode`` and an existing node keeps its own.
        """
        self._insert(key, value, priority, replace_value=True)

    def _insert(
        self,
        key: KT,
        value: "Optional[VT]",
        priority: "Optional[int]",
        replace_value: bool,
    ) -> "Optional[int]":
        """Insert or update ``key`` and return the previous priority of an existing node.

        Returns None when a new node was created. When the key already exists
        and ``replace_value`` is False, its stored value is left untouched
        (``split`` relies on this to keep the value of the threshold key).
        """
        parent = None
        node = self.root_node
        while node is not None:
            if key == node.key:
                previous_priority = node.priority
                if replace_value:
                    node.value = value
                if priority is not None:
                    node.priority = priority
                    # The new priority may violate heap order in either
                    # direction; at most one of these calls rotates anything.
                    self._sift_up(node)
                    self._sift_down(node)
                return previous_priority
            parent = node
            node = node.left_child if key < node.key else node.right_child

        # Standard BST leaf insertion, then restore heap order upward.
        new_node = TreapNode(key, value)
        if priority is not None:
            new_node.priority = priority
        new_node.parent = parent
        if parent is None:
            self.root_node = new_node
        elif key < parent.key:
            parent.left_child = new_node
        else:
            parent.right_child = new_node
        self._sift_up(new_node)
        return None

    # ------------------------------------------------------------------
    # Rotation
    # ------------------------------------------------------------------
    def rotate(self, pivot: TreapNode, direction: str) -> None:
        """Rotate the subtree rooted at ``pivot`` in place.

        ``"right"`` lifts pivot's left child into pivot's position; any other
        value (normally ``"left"``) lifts pivot's right child. BST order is
        preserved, parent links are updated and ``root_node`` follows the
        rotation when ``pivot`` was the root. The child being lifted must exist.
        """
        parent = pivot.parent
        if direction == "right":
            child = pivot.left_child
            inner = child.right_child
            pivot.left_child = inner
            child.right_child = pivot
        else:
            child = pivot.right_child
            inner = child.left_child
            pivot.right_child = inner
            child.left_child = pivot
        # The lifted child's inner subtree moves under pivot.
        if inner is not None:
            inner.parent = pivot
        pivot.parent = child
        child.parent = parent
        if parent is not None:
            if parent.left_child is pivot:
                parent.left_child = child
            else:
                parent.right_child = child
        if self.root_node is pivot:
            self.root_node = child

    # ------------------------------------------------------------------
    # Removal
    # ------------------------------------------------------------------
    def _rotate_down_and_cut(self, node: TreapNode) -> None:
        """Rotate ``node`` down to a leaf (keeping heap order) and unlink it.

        Only priorities and links are used, never keys, so this also works
        for a temporary node whose key is meaningless (see ``join``).
        """
        while True:
            child = self._higher_priority_child(node)
            if child is None:
                break
            # Lift the higher-priority child so heap order holds above node.
            self.rotate(node, "left" if child is node.right_child else "right")

        parent = node.parent
        if parent is not None:
            if parent.left_child is node:
                parent.left_child = None
            else:
                parent.right_child = None
        if self.root_node is node:
            # Removing the last node must leave the map genuinely empty,
            # not pointing at a cleared "zombie" node.
            self.root_node = None
        node.parent = None
        # Clear the detached node so stale references are easy to spot.
        node.key = None
        node.value = None
        node.priority = None

    def remove(self, key: KT) -> Optional[VT]:
        """Remove ``key`` and return its value, or None if the key is absent."""
        node = self._find_node(key)
        if node is None:
            return None
        value = node.value
        self._rotate_down_and_cut(node)
        return value

    # ------------------------------------------------------------------
    # Split / join
    # ------------------------------------------------------------------
    def split(self, threshold: KT) -> "List[Treap[KT, VT]]":
        """Split into ``[keys < threshold, keys >= threshold]``.

        The threshold is lifted to the root with a priority strictly above
        every existing one; its two subtrees become the results. If the
        threshold key existed, it is re-inserted into the right treap with
        its original value and priority. The nodes are moved, not copied, so
        this treap is left empty afterwards.
        """
        if self.root_node is None:
            top_priority = self.MAX_PRIORITY + 1
        else:
            # The root holds the maximum priority; go strictly above it so
            # ties can never stop the threshold short of the root.
            top_priority = max(self.MAX_PRIORITY, self.root_node.priority) + 1
        previous_priority = self._insert(threshold, None, top_priority, replace_value=False)

        pivot = self.root_node
        left_treap = TreapMap()
        right_treap = TreapMap()
        for subtree, treap in ((pivot.left_child, left_treap), (pivot.right_child, right_treap)):
            if subtree is not None:
                # Each subtree root becomes a real root: no dangling parent.
                subtree.parent = None
                treap.root_node = subtree
        pivot.left_child = None
        pivot.right_child = None
        self.root_node = None

        if previous_priority is not None:
            right_treap.insert(threshold, pivot.value, previous_priority)
        return [left_treap, right_treap]

    @staticmethod
    def _key_range(node: TreapNode) -> "typing.Tuple[KT, KT]":
        """Return the (smallest, largest) key in the subtree rooted at ``node``."""
        lowest = highest = node
        while lowest.left_child is not None:
            lowest = lowest.left_child
        while highest.right_child is not None:
            highest = highest.right_child
        return lowest.key, highest.key

    def join(self, _other: "Treap[KT, VT]") -> None:
        """Merge ``_other`` into this treap.

        Every key of ``_other`` must be strictly smaller or strictly larger
        than every key of this treap. ``_other``'s nodes are moved (shared),
        not copied, so ``_other`` should not be used afterwards.

        Raises:
            ValueError: if the key ranges overlap.
        """
        other_root = _other.root_node
        if other_root is None:
            return
        if self.root_node is None:
            self.root_node = other_root
            return

        self_min, self_max = self._key_range(self.root_node)
        other_min, other_max = self._key_range(other_root)
        if self_max < other_min:
            low_root, high_root = self.root_node, other_root
        elif other_max < self_min:
            low_root, high_root = other_root, self.root_node
        else:
            raise ValueError("The treap to be joined has to have keys fully > or < the current treap")

        # Hang both treaps under a temporary root and delete it. Deletion
        # only looks at priorities, so the placeholder key is never compared
        # and no synthetic "max + 1" key (numeric-only) is needed.
        placeholder = TreapNode(self_min, None)
        placeholder.left_child = low_root
        placeholder.right_child = high_root
        low_root.parent = placeholder
        high_root.parent = placeholder
        self.root_node = placeholder
        self._rotate_down_and_cut(placeholder)

    def meld(self, other: "Treap[KT, VT]") -> None: # KARMA
        """Not implemented (optional KARMA exercise)."""
        raise AttributeError

    def difference(self, other: "Treap[KT, VT]") -> None: # KARMA
        """Not implemented (optional KARMA exercise)."""
        raise AttributeError

    def balance_factor(self) -> float: # KARMA
        """Not implemented (optional KARMA exercise)."""
        raise AttributeError

    # ------------------------------------------------------------------
    # Presentation and iteration
    # ------------------------------------------------------------------
    def __str__(self) -> str:
        """Return a pre-order, indented drawing of the tree.

        Each line is ``[priority] <key, value>``; children are prefixed with
        ``|__ L `` or ``|__ R ``. An empty treap renders as
        "This treap is empty".
        """
        if self.root_node is None:
            return "This treap is empty"
        lines: List[str] = []
        # Explicit stack: right is pushed before left so the left subtree is
        # rendered first, matching a recursive pre-order walk.
        stack = [(self.root_node, 0, "")]
        while stack:
            node, depth, child_type = stack.pop()
            tree_branch = "|__" if depth != 0 else ""
            lines.append("{}{}{}[{}] <{}, {}>".format(depth*"  ", tree_branch, child_type, node.priority, node.key, node.value))
            if node.right_child is not None:
                stack.append((node.right_child, depth + 1, " R "))
            if node.left_child is not None:
                stack.append((node.left_child, depth + 1, " L "))
        return "\n".join(lines)

    def __iter__(self) -> typing.Iterator[KT]:
        """Return an independent iterator over the keys in ascending order.

        The keys are snapshotted with an iterative in-order walk (no
        recursion limit on degenerate trees). The snapshot is also stored in
        ``self.keys`` so the legacy ``__next__`` protocol keeps working.
        """
        ordered_keys = []
        stack = []
        node = self.root_node
        while stack or node is not None:
            while node is not None:
                stack.append(node)
                node = node.left_child
            node = stack.pop()
            ordered_keys.append(node.key)
            node = node.right_child
        self.keys = ordered_keys
        self.key_iter_index = 0
        return iter(ordered_keys)

    def __next__(self) -> KT:
        """Legacy cursor over the snapshot taken by the most recent ``__iter__``.

        Raises StopIteration (and rewinds the cursor) once exhausted.
        """
        if self.key_iter_index < len(self.keys):
            key = self.keys[self.key_iter_index]
            self.key_iter_index += 1
            return key
        self.key_iter_index = 0
        raise StopIteration