import time
# from sympy import im
import pickle
from model.partition_node import PartitionNode
import numpy as np
from numpy import genfromtxt
import copy
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D
class PartitionTree:
    '''
    The data structure that represents the partition layout and maintains the
    parent/children relation between partition nodes.
    Designed to provide efficient online queries and to be serializable.
    The node data structure is defined by the PartitionNode class.

    Attributes set by __init__:
        pt_root: the root PartitionNode (nid 0, pid -1).
        nid_node_dict: dictionary mapping node id -> node.
        node_count: number of nodes created so far; also the next node id.

    Attributes read by some methods but NOT set by __init__ (they have to be
    assigned by the caller or a subclass before those methods are used):
        column_width, used_columns: read by save_tree / evaluate_query_cost.
        join_attr, join_depth: when `join_attr` exists the tree is saved and
            loaded in the "join tree" format; `join_depth` is also read by
            query_single_toplayer.
    '''

    # The original code passed `True` as the pickle protocol, which is the
    # integer 1; it is kept so that saved files stay identical.
    _PICKLE_PROTOCOL = 1

    def __init__(self, num_dims=0, boundary=None):
        '''
        Create a tree that only contains the root node.

        num_dims: number of dimensions, forwarded to PartitionNode.
        boundary: root boundary, forwarded to PartitionNode; an empty list is
            used when omitted (a fresh list per tree, never a shared default).
        '''
        if boundary is None:
            boundary = []
        # the node id of root should be 0, its pid should be -1
        # note this initialization does not need dataset and does not set node size!
        self.pt_root = PartitionNode(num_dims, boundary, nid=0, pid=-1, is_irregular_shape_parent=False,
                                     is_irregular_shape=False, num_children=0, children_ids=[], is_leaf=True,
                                     node_size=0)
        self.nid_node_dict = {0: self.pt_root}  # node id to node dictionary
        self.node_count = 1  # the root node

    # = = = = = public functions (API) = = = = =

    def save_tree(self, path):
        '''
        Serialize the tree with pickle and write it to `path`.

        A tree that has a `join_attr` attribute is stored as
        (join_attr, join_depth, column_width, used_columns, nodes); any other
        tree is stored as (column_width, used_columns, nodes).
        '''
        node_list = self.__generate_node_list(self.pt_root)
        serialized_node_list = self.__serialize_by_pickle(node_list)
        # Build the payload before opening the file, so that a missing
        # attribute does not leave an empty/truncated file behind.
        if hasattr(self, 'join_attr'):
            payload = (self.join_attr, self.join_depth, self.column_width, self.used_columns,
                       serialized_node_list)
        else:
            payload = (self.column_width, self.used_columns, serialized_node_list)
        with open(path, "wb") as f:
            pickle.dump(payload, f, self._PICKLE_PROTOCOL)

    def load_tree(self, path):
        '''
        Load a tree previously written by save_tree and rebuild the nodes.

        The expected tuple layout is chosen by whether this object already has
        a `join_attr` attribute (see save_tree).
        '''
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(path, "rb") as f:
            res = pickle.load(f)
        if hasattr(self, 'join_attr'):
            self.join_attr = res[0]
            self.join_depth = res[1]
            self.column_width = res[2]
            self.used_columns = res[3]
            serialized_node_list = res[4]
        else:
            self.column_width = res[0]
            self.used_columns = res[1]
            serialized_node_list = res[2]
        self.__build_tree_from_serialized_node_list(serialized_node_list)
        self.node_count = len(self.nid_node_dict)

    def query_single_toplayer(self, query):
        '''
        Only visit the top-layer nodes of the partition tree to determine the
        coarse range of the query (i.e., the range of the join key).
        The descent stops at leaves or at depth `self.join_depth`.
        '''
        return self.__find_overlapped_partition_consider_depth(self.pt_root, query, 1)

    def query_single(self, query, print_info=False):
        '''
        query is in plain form, i.e., [l1,l2,...,ln, u1,u2,...,un]
        return the ids of the leaf partitions overlapped by the query
        print_info: print the traversal steps when True
        '''
        return self.__find_overlapped_partition(self.pt_root, query, print_info)

    def get_queryset_cost(self, queries):
        '''
        return the cost array directly: for each query, the sum of node_size
        over the leaf partitions it overlaps
        '''
        return [
            sum(self.nid_node_dict[nid].node_size for nid in self.query_single(query))
            for query in queries
        ]

    def evaluate_query_cost(self, queries, print_result=False):
        '''
        get the logical IOs of the queries
        return the average access ratio of the queries, where the access ratio
        of one query is
            (rows in overlapped leaves * width of used columns)
            / (root node_size * width of all columns)
        return 0 for an empty query list (the original raised ZeroDivisionError)

        print_result: accepted for interface compatibility; it is not
            consulted, the per-query and total statistics are always printed
        '''
        if len(queries) == 0:
            return 0
        table_size = self.pt_root.node_size
        column_width = self.column_width
        row_width = sum(column_width[col] for col in column_width)
        access_width = sum(column_width[col] for col in self.used_columns)
        total_cost = 0
        total_ratio = 0
        for count, query in enumerate(queries):
            overlapped_leaf_ids = self.query_single(query)
            print(f"~~~~~~~~~Query#{count}: {overlapped_leaf_ids}")
            row_count = sum(self.nid_node_dict[nid].node_size for nid in overlapped_leaf_ids)
            access_ratio = row_count * access_width / (table_size * row_width)
            print(f"query #{count}: Row Count: {row_count}, Access Ratio:{access_ratio}")
            total_ratio += access_ratio
            total_cost += row_count
        print("Total logical IOs:", total_cost)
        print("Average logical IOs:", total_cost // len(queries))
        return total_ratio / len(queries)

    def get_pid_for_data_point(self, point):
        '''
        get the corresponding leaf partition nid(s) for a data point
        point: [dim1_value, dim2_value...], contains the same dimensions as the partition tree
        return a list of node ids; [-1] when no partition contains the point
        '''
        return self.__find_resided_partition(self.pt_root, point)

    def add_node(self, parent_id, child_node):
        '''
        Register child_node under the node parent_id: assign the next free nid,
        set pid and depth, and update the parent's children bookkeeping
        (the parent is no longer a leaf afterwards).
        '''
        parent = self.nid_node_dict[parent_id]
        child_node.nid = self.node_count
        self.node_count += 1
        child_node.pid = parent_id
        self.nid_node_dict[child_node.nid] = child_node
        child_node.depth = parent.depth + 1
        parent.children_ids.append(child_node.nid)
        parent.num_children += 1
        parent.is_leaf = False

    def apply_split(self, parent_nid, split_dim, split_value, split_type=0, extended_bound=None, approximate=False,
                    pretend=False):
        '''
        split_type = 0: split a node into 2 sub-nodes by a given dimension and value, distribute dataset
        Any other split_type is not implemented in this class: "Invalid Split Type!"
        is printed and (None, None) is returned without modifying the tree.

        extended_bound, approximate: accepted for interface compatibility, not used here
        pretend: if pretend is True, return the split result, but do not apply this split

        return (child_node1, child_node2): child_node1 covers values < split_value on
        split_dim, child_node2 covers values >= split_value
        '''
        parent_node = self.nid_node_dict[parent_nid]
        if pretend:
            # work on a private copy so that the real node stays untouched
            parent_node = copy.deepcopy(parent_node)

        if split_type != 0:
            print("Invalid Split Type!")
            # No split happened, so the parent's dataset/queryset must be kept
            # (the original code deleted them here).
            return None, None

        def special_copy(target_node):
            # shallow copy of the node with its own boundary and without data
            new_node = copy.copy(target_node)
            new_node.boundary = copy.deepcopy(target_node.boundary)
            new_node.dataset = None
            new_node.queryset = None
            return new_node

        # create sub nodes
        child_node1 = special_copy(parent_node)
        child_node1.boundary[split_dim + child_node1.num_dims] = split_value  # new upper bound
        child_node1.children_ids = []
        child_node2 = special_copy(parent_node)
        child_node2.boundary[split_dim] = split_value  # new lower bound
        child_node2.children_ids = []

        # `is not None` is required: the truth value of a numpy array is ambiguous
        if parent_node.dataset is not None:
            goes_left = parent_node.dataset[:, split_dim] < split_value
            child_node1.dataset = parent_node.dataset[goes_left]
            child_node1.node_size = len(child_node1.dataset)
            child_node2.dataset = parent_node.dataset[~goes_left]
            child_node2.node_size = len(child_node2.dataset)
        if parent_node.queryset is not None:
            left_part, right_part, mid_part = parent_node.split_queryset(split_dim, split_value)
            # queries crossing the split value are given to both children
            child_node1.queryset = left_part + mid_part
            child_node2.queryset = right_part + mid_part

        if not pretend:
            # real split: update the tree and release the data held by the parent
            self.add_node(parent_nid, child_node1)
            self.add_node(parent_nid, child_node2)
            real_parent = self.nid_node_dict[parent_nid]
            real_parent.split_type = "candidate cut"
            del real_parent.dataset
            del real_parent.queryset
        return child_node1, child_node2

    def get_leaves(self, use_partitionable=False):
        '''
        return the leaf nodes in node-dictionary order
        use_partitionable: when True, only leaves whose `partitionable` is truthy
        '''
        if use_partitionable:
            return [node for node in self.nid_node_dict.values() if node.is_leaf and node.partitionable]
        return [node for node in self.nid_node_dict.values() if node.is_leaf]

    def visualize(self, dims=(0, 1), queries=(), path=None, focus_region=None, add_text=True, use_sci=False):
        '''
        visualize the partition tree's leaf nodes
        dims: the dimensions to plot; 2 entries give a 2D plot, otherwise the
            first 3 entries are used for a 3D plot
        queries: queries (plain form) drawn in red on top of the partitions
        path: when not None, the figure is saved there
        focus_region: in the shape of boundary (only used by the 2D plot)
        add_text, use_sci: only used by the 2D plot
        '''
        if len(dims) == 2:
            self.__visualize_2d(dims, queries, path, focus_region, add_text, use_sci)
        else:
            self.__visualize_3d(dims[0:3], queries, path, focus_region)

    # = = = = = internal functions = = = = =

    def __extract_sub_dataset(self, dataset, query):
        '''
        return the rows of dataset that fall inside the query on every dimension
        (both bounds inclusive)
        '''
        constraints = []
        num_dims = self.pt_root.num_dims
        for d in range(num_dims):
            constraints.append(dataset[:, d] >= query[d])
            constraints.append(dataset[:, d] <= query[num_dims + d])
        constraint = np.all(constraints, axis=0)
        return dataset[constraint]

    def __generate_node_list(self, node):
        '''
        return `node` and all of its descendants in pre-order
        (an explicit stack is used instead of recursion)
        '''
        node_list = []
        stack = [node]
        while stack:
            current = stack.pop()
            node_list.append(current)
            # push in reverse so that children are visited in their stored order
            stack.extend(self.nid_node_dict[nid] for nid in reversed(current.children_ids))
        return node_list

    def __serialize_by_pickle(self, node_list):
        '''
        turn every node into a flat attribute list:
        [num_dims, boundary, nid, pid, is_irregular_shape_parent(0/1),
         is_irregular_shape(0/1), num_children, is_leaf(0/1), node_size, depth]
        '''
        serialized_node_list = []
        for node in node_list:
            serialized_node_list.append([
                node.num_dims,
                node.boundary,
                node.nid,
                node.pid,
                1 if node.is_irregular_shape_parent else 0,
                1 if node.is_irregular_shape else 0,
                node.num_children,
                1 if node.is_leaf else 0,
                node.node_size,
                node.depth,
            ])
        return serialized_node_list

    def __build_tree_from_serialized_node_list(self, serialized_node_list):
        '''
        rebuild nid_node_dict and pt_root from the lists produced by
        __serialize_by_pickle; children lists are reconstructed from the pids
        '''
        self.pt_root = None
        self.nid_node_dict.clear()
        pid_children_ids_dict = {}
        for serialized_node in serialized_node_list:
            (num_dims, boundary, nid, pid, irregular_parent_flag, irregular_flag,
             num_children, leaf_flag, node_size, depth) = serialized_node[:10]
            node = PartitionNode(num_dims, boundary, nid, pid, irregular_parent_flag != 0,
                                 irregular_flag != 0, num_children, [], leaf_flag != 0, node_size)
            node.depth = depth
            self.nid_node_dict[nid] = node  # update dict
            pid_children_ids_dict.setdefault(node.pid, []).append(node.nid)

        # make sure the irregular shape partition is placed at the end of the child list
        for pid, children_ids in pid_children_ids_dict.items():
            if pid == -1:
                continue
            parent = self.nid_node_dict[pid]
            if parent.is_irregular_shape_parent and not self.nid_node_dict[children_ids[-1]].is_irregular_shape:
                # Stable partition: regular children first, irregular ones last.
                # Unlike the original, this never appends None and never drops a child.
                regular_ids = [nid for nid in children_ids if not self.nid_node_dict[nid].is_irregular_shape]
                irregular_ids = [nid for nid in children_ids if self.nid_node_dict[nid].is_irregular_shape]
                parent.children_ids = regular_ids + irregular_ids
            else:
                parent.children_ids = children_ids
        self.pt_root = self.nid_node_dict[0]

    def __bound_query_by_boundary(self, query, boundary):
        '''
        bound the query by a node's boundary (the input query is not modified)
        '''
        bounded_query = query.copy()
        num_dims = self.pt_root.num_dims
        for dim in range(num_dims):
            bounded_query[dim] = max(query[dim], boundary[dim])
            bounded_query[num_dims + dim] = min(query[num_dims + dim], boundary[num_dims + dim])
        return bounded_query

    def __find_resided_partition(self, node, point):
        '''
        for data point only
        return [leaf nid] (followed by the linked ids whose node
        `is_redundant_contain` the point, if the leaf has linked_ids),
        or [-1] when no partition contains the point
        '''
        if node.is_leaf and node.is_contain(point):
            if not node.linked_ids:
                return [node.nid]
            write_ids = [node.nid]
            for link_id in node.linked_ids:
                if self.nid_node_dict[link_id].is_redundant_contain(point):
                    write_ids.append(link_id)
            return write_ids
        # descend into the first child that contains the point
        for nid in node.children_ids:
            child = self.nid_node_dict[nid]
            if child.is_contain(point):
                return self.__find_resided_partition(child, point)
        return [-1]

    def __find_overlapped_partition_consider_depth(self, node, query, depth):
        '''
        like __find_overlapped_partition, but a node reached at
        depth == self.join_depth is treated as a leaf
        '''
        if node.is_leaf or depth == self.join_depth:
            return [node.nid] if node.is_overlap(query) > 0 else []
        node_id_list = []
        for nid in node.children_ids:
            node_id_list += self.__find_overlapped_partition_consider_depth(self.nid_node_dict[nid], query,
                                                                            depth + 1)
        return node_id_list

    def __find_overlapped_partition(self, node, query, print_info=False):
        '''
        return the ids of the leaves under `node` that overlap the query;
        children are only searched when `node` itself overlaps the query
        '''
        if print_info:
            print("Enter node", node.nid)
        if node.is_leaf:
            if print_info:
                print("node", node.nid, "is leaf")
            if node.is_overlap(query) > 0:
                if print_info:
                    print("node", node.nid, "is added as result")
                return [node.nid]
            return []
        node_id_list = []
        if node.is_overlap(query) > 0:
            if print_info:
                print("searching childrens for node", node.nid)
            for nid in node.children_ids:
                node_id_list += self.__find_overlapped_partition(self.nid_node_dict[nid], query, print_info)
        # remove duplicated ids
        return list(set(node_id_list))

    def __visualize_2d(self, dims, queries=(), path=None, focus_region=None, add_text=True, use_sci=False):
        '''
        draw the leaves (black) and the queries (red) projected on dims[0], dims[1]
        focus_region: in the shape of a 2-dimensional boundary,
            i.e. [lower1, lower2, upper1, upper2]
        add_text: label queries with their index and leaves with their nid
        use_sci: use scientific notation for the tick labels
        path: when not None, the figure is saved there before it is shown
        '''
        fig, ax = plt.subplots(1)
        num_dims = self.pt_root.num_dims
        dim_x, dim_y = dims[0], dims[1]
        root_boundary = self.pt_root.boundary
        plt.xlim(root_boundary[dim_x], root_boundary[dim_x + num_dims])
        plt.ylim(root_boundary[dim_y], root_boundary[dim_y + num_dims])

        for case, query in enumerate(queries):
            lower1, lower2 = query[dim_x], query[dim_y]
            upper1, upper2 = query[dim_x + num_dims], query[dim_y + num_dims]
            rect = Rectangle((lower1, lower2), upper1 - lower1, upper2 - lower2, fill=False, edgecolor='r',
                             linewidth=1)
            if add_text:
                ax.text(upper1, upper2, case, color='b', fontsize=7)
            ax.add_patch(rect)

        for leaf in self.get_leaves():
            lower1, lower2 = leaf.boundary[dim_x], leaf.boundary[dim_y]
            upper1, upper2 = leaf.boundary[dim_x + num_dims], leaf.boundary[dim_y + num_dims]
            rect = Rectangle((lower1, lower2), upper1 - lower1, upper2 - lower2, fill=False, edgecolor='black',
                             linewidth=1)
            if add_text:
                ax.text(lower1, lower2, leaf.nid, fontsize=7)
            ax.add_patch(rect)

        ax.set_xlabel('dimension 1', fontsize=15)
        ax.set_ylabel('dimension 2', fontsize=15)
        if use_sci:
            plt.ticklabel_format(axis="both", style="sci", scilimits=(0, 0))
        plt.tight_layout()  # preventing clipping the labels when save to pdf

        if focus_region is not None:
            # reform focus region into interleaved format [xmin, xmax, ymin, ymax]
            formated_focus_region = []
            for i in range(2):
                formated_focus_region.append(focus_region[i])
                formated_focus_region.append(focus_region[2 + i])
            plt.axis(formated_focus_region)

        # The caller's `path` is honoured (the original overwrote it with a
        # hard-coded, machine-specific location that also required self.name).
        if path is not None:
            fig.savefig(path)
        plt.show()

    @staticmethod
    def _plot_box_edges(ax, lower, upper, color):
        '''
        draw the 12 edges of the axis-aligned box spanned by the corner points
        lower = (x, y, z) and upper = (x, y, z) on the 3D axes `ax`
        '''
        lo_x, lo_y, lo_z = lower
        hi_x, hi_y, hi_z = upper
        # 4 edges parallel to the x axis
        for y in (lo_y, hi_y):
            for z in (lo_z, hi_z):
                ax.plot3D([lo_x, hi_x], [y, y], [z, z], color=color)
        # 4 edges parallel to the y axis
        for x in (lo_x, hi_x):
            for z in (lo_z, hi_z):
                ax.plot3D([x, x], [lo_y, hi_y], [z, z], color=color)
        # 4 edges parallel to the z axis
        for x in (lo_x, hi_x):
            for y in (lo_y, hi_y):
                ax.plot3D([x, x], [y, y], [lo_z, hi_z], color=color)

    def __visualize_3d(self, dims, queries=(), path=None, focus_region=None):
        '''
        draw the leaves (green) and the queries (red) as wireframe boxes
        projected on dims[0], dims[1], dims[2]
        path: when not None, the figure is saved there before it is shown
        focus_region: accepted for interface compatibility, not used here
        '''
        fig = plt.figure()
        # Axes3D(fig) does not attach the axes to the figure on recent
        # matplotlib versions; add_subplot works on old and new ones.
        ax = fig.add_subplot(111, projection='3d')
        num_dims = self.pt_root.num_dims
        root_boundary = self.pt_root.boundary
        ax.set_xlim(root_boundary[dims[0]], root_boundary[dims[0] + num_dims])
        ax.set_ylim(root_boundary[dims[1]], root_boundary[dims[1] + num_dims])
        ax.set_zlim(root_boundary[dims[2]], root_boundary[dims[2] + num_dims])

        for leaf in self.get_leaves():
            lower = [leaf.boundary[d] for d in dims]
            upper = [leaf.boundary[d + num_dims] for d in dims]
            self._plot_box_edges(ax, lower, upper, "g")

        for query in queries:
            lower = [query[d] for d in dims]
            upper = [query[d + num_dims] for d in dims]
            self._plot_box_edges(ax, lower, upper, "r")

        if path is not None:
            fig.savefig(path)
        plt.show()