import time
# from sympy import im
import pickle
from model.partition_node import PartitionNode
import numpy as np
from numpy import genfromtxt
import copy
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib


class PartitionTree:
    '''
    The data structure that represents the partition layout and maintains the
    parent/children relations between partition nodes.

    Designed to provide efficient online queries and serialization.
    See the PartitionNode class for the node data structure.
    '''

    def __init__(self, num_dims = 0, boundary = None):
        '''
        Create a tree that contains only the root node.

        num_dims: number of dimensions of the partitioned space
        boundary: root boundary in plain form [l1, ..., ln, u1, ..., un];
                  defaults to a fresh empty list (not a shared mutable default)

        The root node id is 0 and its parent id is -1. Note that this
        initialization does not need a dataset and does not set the node size.
        '''
        if boundary is None:
            boundary = []
        self.pt_root = PartitionNode(num_dims, boundary, nid = 0, pid = -1, is_irregular_shape_parent = False,
                                     is_irregular_shape = False, num_children = 0, children_ids = [],
                                     is_leaf = True, node_size = 0)
        self.nid_node_dict = {0: self.pt_root}  # node id -> node
        self.node_count = 1  # next node id handed out by add_node (the root already uses 0)

    # = = = = = public functions (API) = = = = =

    def save_tree(self, path):
        '''
        Serialize the tree to `path` with pickle.

        Instances that have a `join_attr` attribute are stored as
        (join_attr, join_depth, column_width, used_columns, serialized_nodes);
        all others as (column_width, used_columns, serialized_nodes).
        The payload is assembled before the file is opened, so a missing
        attribute raises without truncating an existing file.
        '''
        node_list = self.__generate_node_list(self.pt_root)
        serialized_node_list = self.__serialize_by_pickle(node_list)
        if hasattr(self, 'join_attr'):
            payload = (self.join_attr, self.join_depth, self.column_width, self.used_columns, serialized_node_list)
        else:
            payload = (self.column_width, self.used_columns, serialized_node_list)
        with open(path, "wb") as f:
            pickle.dump(payload, f, True)

    def load_tree(self, path):
        '''
        Load a tree previously written by save_tree and rebuild the node dictionary.

        The layout is detected from the stored tuple itself: 5 entries means a
        join tree (join_attr and join_depth are restored), otherwise the
        3-entry normal layout is assumed.

        NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        '''
        with open(path, "rb") as f:
            res = pickle.load(f)
        if len(res) == 5:
            self.join_attr = res[0]
            self.join_depth = res[1]
            self.column_width = res[2]
            self.used_columns = res[3]
            serialized_node_list = res[4]
        else:
            self.column_width = res[0]
            self.used_columns = res[1]
            serialized_node_list = res[2]
        self.__build_tree_from_serialized_node_list(serialized_node_list)
        # next id must not collide with any loaded id
        self.node_count = max(self.nid_node_dict) + 1

    def query_single_toplayer(self, query):
        '''
        Only visit the top-layer nodes of the partition tree (down to
        `self.join_depth`) to determine the coarse range of the query,
        i.e. the range of the join key. Requires `join_depth` to be set.
        '''
        return self.__find_overlapped_partition_consider_depth(self.pt_root, query, 1)

    def query_single(self, query, print_info = False):
        '''
        query is in plain form, i.e., [l1, l2, ..., ln, u1, u2, ..., un].
        Return the ids of the leaf partitions overlapping the query
        (deduplicated; order is not guaranteed).

        print_info: if True, trace the tree traversal to stdout.
        '''
        return self.__find_overlapped_partition(self.pt_root, query, print_info)

    def get_queryset_cost(self, queries):
        '''
        Return a list with, for each query, the summed node_size of the leaf
        partitions it overlaps.
        '''
        return [
            sum(self.nid_node_dict[nid].node_size for nid in self.query_single(query))
            for query in queries
        ]

    def evaluate_query_cost(self, queries, print_result = False):
        '''
        Get the logical IOs of the queries and return the average access ratio.

        For each query, rows read = summed node_size of the overlapped leaves,
        and access ratio = rows_read * accessed_width / (table_rows * row_width),
        where widths come from `self.column_width` / `self.used_columns`.
        Per-query lines and the total / average logical IOs are printed.
        Returns 0 for an empty query list.

        print_result: kept for interface compatibility; output is always printed.
        '''
        if len(queries) == 0:
            return 0
        tablesize = self.pt_root.node_size
        column_width = self.column_width
        row_width = sum(column_width[col] for col in column_width)
        access_width = sum(column_width[col] for col in self.used_columns)

        total_cost = 0
        total_ratio = 0
        for count, query in enumerate(queries):
            overlapped_leaf_ids = self.query_single(query)
            print(f"~~~~~~~~~Query#{count}: {overlapped_leaf_ids}")
            cost = sum(self.nid_node_dict[nid].node_size for nid in overlapped_leaf_ids)
            access_ratio = cost*access_width/(tablesize*row_width)
            print(f"query #{count}: Row Count: {cost}, Access Ratio:{access_ratio}")
            total_ratio += access_ratio
            total_cost += cost
        print("Total logical IOs:", total_cost)
        print("Average logical IOs:", total_cost // len(queries))
        return total_ratio / len(queries)

    def get_pid_for_data_point(self, point):
        '''
        Get the partition ids a data point resides in / should be written to.

        point: [dim1_value, dim2_value, ...], with the same dimensions as the tree.
        Returns a list of node ids: the containing leaf plus any linked partitions
        that redundantly contain the point, or [-1] if no containing leaf is found.
        '''
        return self.__find_resided_partition(self.pt_root, point)

    def add_node(self, parent_id, child_node):
        '''
        Attach `child_node` under `parent_id`, assigning it the next free node id
        and a depth one greater than its parent. The parent stops being a leaf.
        '''
        parent = self.nid_node_dict[parent_id]
        child_node.nid = self.node_count
        self.node_count += 1
        child_node.pid = parent_id
        self.nid_node_dict[child_node.nid] = child_node
        child_node.depth = parent.depth + 1
        parent.children_ids.append(child_node.nid)
        parent.num_children += 1
        parent.is_leaf = False

    def apply_split(self, parent_nid, split_dim, split_value, split_type = 0, extended_bound = None, approximate = False, pretend = False):
        '''
        Split the node `parent_nid` and return the two resulting child nodes.

        split_type = 0: split a node into 2 sub-nodes by a given dimension and
                        value, distributing its dataset and queryset.
        Only split_type 0 is implemented here; any other value prints
        "Invalid Split Type!", leaves the tree and the parent's data untouched,
        and returns (None, None).
        (Historically planned: 1 = bounding split, 2 = dual-bounding split,
        3 = var-bounding split; extended_bound / approximate are unused here.)

        pretend: if True, compute the split on a deep copy of the parent and
                 return the children without modifying the tree.
        '''
        parent_node = self.nid_node_dict[parent_nid]
        if pretend:
            parent_node = copy.deepcopy(parent_node)

        child_node1, child_node2 = None, None

        def special_copy(target_node):
            # shallow copy with an independent boundary; data is redistributed below
            new_node = copy.copy(target_node)
            new_node.boundary = copy.deepcopy(target_node.boundary)
            new_node.dataset = None
            new_node.queryset = None
            return new_node

        if split_type != 0:
            # do not discard the parent's data when no split was performed
            print("Invalid Split Type!")
            return child_node1, child_node2

        # child 1 takes [lower, split_value), child 2 takes [split_value, upper]
        child_node1 = special_copy(parent_node)
        child_node1.boundary[split_dim + child_node1.num_dims] = split_value
        child_node1.children_ids = []

        child_node2 = special_copy(parent_node)
        child_node2.boundary[split_dim] = split_value
        child_node2.children_ids = []

        # `is not None` is required: truth-testing a multi-element numpy array raises
        if parent_node.dataset is not None:
            child_node1.dataset = parent_node.dataset[parent_node.dataset[:, split_dim] < split_value]
            child_node1.node_size = len(child_node1.dataset)
            child_node2.dataset = parent_node.dataset[parent_node.dataset[:, split_dim] >= split_value]
            child_node2.node_size = len(child_node2.dataset)

        if parent_node.queryset is not None:
            left_part, right_part, mid_part = parent_node.split_queryset(split_dim, split_value)
            # queries crossing the split value go to both children
            child_node1.queryset = left_part + mid_part
            child_node2.queryset = right_part + mid_part

        if not pretend:
            self.add_node(parent_nid, child_node1)
            self.add_node(parent_nid, child_node2)
            self.nid_node_dict[parent_nid].split_type = "candidate cut"
            # real split: the parent no longer holds data itself
            del self.nid_node_dict[parent_nid].dataset
            del self.nid_node_dict[parent_nid].queryset

        return child_node1, child_node2

    def get_leaves(self, use_partitionable = False):
        '''
        Return all leaf nodes; with use_partitionable=True, only leaves whose
        `partitionable` flag is truthy.
        '''
        return [
            node for node in self.nid_node_dict.values()
            if node.is_leaf and (not use_partitionable or node.partitionable)
        ]

    def visualize(self, dims = (0, 1), queries = None, path = None, focus_region = None, add_text = True, use_sci = False):
        '''
        Visualize the partition tree's leaf nodes (and optional query boxes).

        dims: two dimension indices for a 2-D plot; otherwise the first three
              are used for a 3-D plot.
        path: if given, the figure is saved there before being shown.
        focus_region: in the shape of a boundary (2-D plot only).
        '''
        if queries is None:
            queries = []
        if len(dims) == 2:
            self.__visualize_2d(dims, queries, path, focus_region, add_text, use_sci)
        else:
            self.__visualize_3d(dims[0:3], queries, path, focus_region)

    # = = = = = internal functions = = = = =

    def __extract_sub_dataset(self, dataset, query):
        '''
        Return the rows of `dataset` whose first num_dims columns lie inside the
        closed query box [l1..ln, u1..un]. Assumes a 2-D numpy array with one
        column per dimension.
        '''
        num_dims = self.pt_root.num_dims
        constraints = []
        for d in range(num_dims):
            constraints.append(dataset[:, d] >= query[d])
            constraints.append(dataset[:, d] <= query[num_dims + d])
        return dataset[np.all(constraints, axis=0)]

    def __generate_node_list(self, node):
        '''
        Return `node` and all its descendants in preorder (children visited in
        list order). Iterative to avoid the recursion limit on deep trees.
        '''
        node_list = []
        stack = [node]
        while stack:
            current = stack.pop()
            node_list.append(current)
            # push in reverse so the first child is popped (visited) first
            stack.extend(self.nid_node_dict[nid] for nid in reversed(current.children_ids))
        return node_list

    def __serialize_by_pickle(self, node_list):
        '''
        Flatten each node into
        [num_dims, boundary, nid, pid, irregular_parent(0/1), irregular(0/1),
         num_children, is_leaf(0/1), node_size, depth].
        '''
        return [
            [node.num_dims, node.boundary, node.nid, node.pid,
             1 if node.is_irregular_shape_parent else 0,
             1 if node.is_irregular_shape else 0,
             node.num_children,
             1 if node.is_leaf else 0,
             node.node_size,
             node.depth]
            for node in node_list
        ]

    def __build_tree_from_serialized_node_list(self, serialized_node_list):
        '''
        Rebuild `nid_node_dict` and `pt_root` from the output of __serialize_by_pickle.

        Children lists are reconstructed from each node's pid. For irregular-shape
        parents, irregular-shape children are moved to the end of the children list.
        '''
        self.pt_root = None
        self.nid_node_dict.clear()
        pid_children_ids_dict = {}

        for serialized_node in serialized_node_list:
            num_dims = serialized_node[0]
            boundary = serialized_node[1]
            nid = serialized_node[2]
            pid = serialized_node[3]
            is_irregular_shape_parent = serialized_node[4] != 0
            is_irregular_shape = serialized_node[5] != 0
            num_children = serialized_node[6]
            is_leaf = serialized_node[7] != 0
            node_size = serialized_node[8]
            node = PartitionNode(num_dims, boundary, nid, pid, is_irregular_shape_parent,
                                 is_irregular_shape, num_children, [], is_leaf, node_size)
            node.depth = serialized_node[9]
            self.nid_node_dict[nid] = node
            pid_children_ids_dict.setdefault(node.pid, []).append(node.nid)

        for pid, children_ids in pid_children_ids_dict.items():
            if pid == -1:
                continue
            parent = self.nid_node_dict[pid]
            if parent.is_irregular_shape_parent and not self.nid_node_dict[children_ids[-1]].is_irregular_shape:
                # keep every child (never insert None), irregular-shape ones last
                regular_ids = [cid for cid in children_ids if not self.nid_node_dict[cid].is_irregular_shape]
                irregular_ids = [cid for cid in children_ids if self.nid_node_dict[cid].is_irregular_shape]
                parent.children_ids = regular_ids + irregular_ids
            else:
                parent.children_ids = children_ids

        self.pt_root = self.nid_node_dict[0]

    def __bound_query_by_boundary(self, query, boundary):
        '''
        Return a copy of `query` clipped to a node's `boundary` in every dimension.
        '''
        bounded_query = query.copy()
        num_dims = self.pt_root.num_dims
        for dim in range(num_dims):
            bounded_query[dim] = max(query[dim], boundary[dim])
            bounded_query[num_dims + dim] = min(query[num_dims + dim], boundary[num_dims + dim])
        return bounded_query

    def __find_resided_partition(self, node, point):
        '''
        For data points only: descend into the first child containing `point`.

        At a containing leaf, return [leaf_nid] plus every id in `linked_ids`
        whose node `is_redundant_contain(point)`. Return [-1] when no child
        along the path contains the point.
        '''
        if node.is_leaf and node.is_contain(point):
            write_ids = [node.nid]
            for link_id in node.linked_ids or []:
                if self.nid_node_dict[link_id].is_redundant_contain(point):
                    write_ids.append(link_id)
            return write_ids
        for nid in node.children_ids:
            child = self.nid_node_dict[nid]
            if child.is_contain(point):
                return self.__find_resided_partition(child, point)
        return [-1]

    def __find_overlapped_partition_consider_depth(self, node, query, depth):
        '''
        Like __find_overlapped_partition, but nodes at depth `self.join_depth`
        are treated as leaves. Nodes above that depth are not pruned by overlap;
        only the returned nodes are overlap-tested.
        '''
        if node.is_leaf or depth == self.join_depth:
            return [node.nid] if node.is_overlap(query) > 0 else []
        node_id_list = []
        for nid in node.children_ids:
            node_id_list += self.__find_overlapped_partition_consider_depth(self.nid_node_dict[nid], query, depth + 1)
        return node_id_list

    def __find_overlapped_partition(self, node, query, print_info = False):
        '''
        Return the deduplicated ids of leaves under `node` that overlap `query`,
        pruning subtrees whose root does not overlap.
        '''
        if print_info:
            print("Enter node", node.nid)
        overlaps = node.is_overlap(query) > 0  # evaluated once per node

        if node.is_leaf:
            if print_info:
                print("node", node.nid, "is leaf")
                if overlaps:
                    print("node", node.nid, "is added as result")
            return [node.nid] if overlaps else []

        node_id_list = []
        if overlaps:
            if print_info:
                print("searching childrens for node", node.nid)
            for nid in node.children_ids:
                node_id_list += self.__find_overlapped_partition(self.nid_node_dict[nid], query, print_info)
        return list(set(node_id_list))

    def __visualize_2d(self, dims, queries = None, path = None, focus_region = None, add_text = True, use_sci = False):
        '''
        Draw query boxes (red, labelled by index) and leaf boxes (black,
        labelled by nid) projected onto dimensions dims[0] / dims[1].

        path: saved only when the caller provides one.
        focus_region: in the shape of a boundary; indices 0..3 are read as
                      [l1, l2, u1, u2] (assumes a 2-D boundary — TODO confirm
                      for trees with more dimensions).
        '''
        if queries is None:
            queries = []
        fig, ax = plt.subplots(1)
        num_dims = self.pt_root.num_dims
        root_boundary = self.pt_root.boundary
        ax.set_xlim(root_boundary[dims[0]], root_boundary[dims[0] + num_dims])
        ax.set_ylim(root_boundary[dims[1]], root_boundary[dims[1] + num_dims])

        for case, query in enumerate(queries):
            lower1, lower2 = query[dims[0]], query[dims[1]]
            upper1, upper2 = query[dims[0] + num_dims], query[dims[1] + num_dims]
            rect = Rectangle((lower1, lower2), upper1 - lower1, upper2 - lower2, fill=False, edgecolor='r', linewidth=1)
            if add_text:
                ax.text(upper1, upper2, case, color='b', fontsize=7)
            ax.add_patch(rect)

        for leaf in self.get_leaves():
            lower1, lower2 = leaf.boundary[dims[0]], leaf.boundary[dims[1]]
            upper1, upper2 = leaf.boundary[dims[0] + num_dims], leaf.boundary[dims[1] + num_dims]
            rect = Rectangle((lower1, lower2), upper1 - lower1, upper2 - lower2, fill=False, edgecolor='black', linewidth=1)
            if add_text:
                ax.text(lower1, lower2, leaf.nid, fontsize=7)
            ax.add_patch(rect)

        ax.set_xlabel('dimension 1', fontsize=15)
        ax.set_ylabel('dimension 2', fontsize=15)
        if use_sci:
            ax.ticklabel_format(axis="both", style="sci", scilimits=(0, 0))
        fig.tight_layout()  # prevent clipping the labels when saving to pdf

        if focus_region is not None:
            # reform focus region into interleaved [xmin, xmax, ymin, ymax] format
            formated_focus_region = []
            for i in range(2):
                formated_focus_region.append(focus_region[i])
                formated_focus_region.append(focus_region[2 + i])
            ax.axis(formated_focus_region)

        # respect the caller's path (previously overwritten with a hard-coded one)
        if path is not None:
            fig.savefig(path)
        plt.show()

    @staticmethod
    def _plot_box_3d(ax, lower, upper, color):
        '''Draw the 12 edges of the axis-aligned box from `lower` to `upper`, each (x, y, z).'''
        (l1, l2, l3), (u1, u2, u3) = lower, upper
        # 4 edges parallel to the x axis
        for y, z in ((l2, l3), (u2, l3), (u2, u3), (l2, u3)):
            ax.plot3D([l1, u1], [y, y], [z, z], color=color)
        # 4 edges parallel to the y axis
        for x, z in ((l1, l3), (u1, l3), (u1, u3), (l1, u3)):
            ax.plot3D([x, x], [l2, u2], [z, z], color=color)
        # 4 edges parallel to the z axis
        for x, y in ((l1, l2), (u1, l2), (u1, u2), (l1, u2)):
            ax.plot3D([x, x], [y, y], [l3, u3], color=color)

    def __visualize_3d(self, dims, queries = None, path = None, focus_region = None):
        '''
        Draw leaf boxes (green) and query boxes (red) over the three dimensions
        in `dims`. focus_region is accepted but not used here.
        '''
        if queries is None:
            queries = []
        fig = plt.figure()
        # Axes3D(fig) no longer adds itself to the figure on recent matplotlib
        ax = fig.add_subplot(111, projection='3d')
        num_dims = self.pt_root.num_dims
        root_boundary = self.pt_root.boundary
        ax.set_xlim(root_boundary[dims[0]], root_boundary[dims[0] + num_dims])
        ax.set_ylim(root_boundary[dims[1]], root_boundary[dims[1] + num_dims])
        ax.set_zlim(root_boundary[dims[2]], root_boundary[dims[2] + num_dims])

        for leaf in self.get_leaves():
            self._plot_box_3d(ax,
                              [leaf.boundary[d] for d in dims],
                              [leaf.boundary[d + num_dims] for d in dims],
                              "g")
        for query in queries:
            self._plot_box_3d(ax,
                              [query[d] for d in dims],
                              [query[d + num_dims] for d in dims],
                              "r")

        if path is not None:
            fig.savefig(path)
        plt.show()