import ast import os import astor from _ast import (withitem, alias, keyword, arg, arguments, ExceptHandler, comprehension, NotIn, NotEq, LtE, Lt, IsNot, Is, In, GtE, Gt, Eq, USub, UAdd, Not, Invert, Sub, RShift, Pow, MatMult, Mult, Mod, LShift, FloorDiv, Div, BitXor, BitOr, BitAnd, Add, Or, And, Store, Load, Del, Tuple, List, Name, Starred, Subscript, Attribute, NamedExpr, Constant, JoinedStr, FormattedValue, Call, Compare, YieldFrom, Yield, Await, GeneratorExp, DictComp, SetComp, ListComp, Set, Dict, IfExp, Lambda, UnaryOp, BinOp, BoolOp, Slice, Continue, Break, Pass, Expr, Nonlocal, Global, ImportFrom, Import, Assert, Try, Raise, AsyncWith, With, If, While, AsyncFor, For, AnnAssign, AugAssign, Assign, Delete, Return, ClassDef, AsyncFunctionDef, FunctionDef, Expression, Interactive, AST) from typing import Any, cast from collections import deque import pandas as pd import src.Calls as Calls from src.datatypes import GraphInformation from src.Calls import File, pd_dataframe, packages from src.util import is_file, ControlFlow, format_node_text, get_package from src.ast_package import AstPackage, get_ast_package from src.ast_package.types import CallComponents, CallArgumentsComponents, AssignComponents, BinOpComponents, \ AttributeComponents def insert_parameter(parameters: dict, is_block: bool, parameter: str, value): if parameter is None: return if is_block: parameters[str(parameter)].append(value) else: parameters[str(parameter)] = value class NodeVisitor(ast.NodeVisitor): graph_info: GraphInformation # TODO: REWORK VARIABLE NAMING def __init__(self, graph_information: GraphInformation = None): self.graph_info = graph_information self.columns = [] self.files = {} self.variables = {} self.alias = {} self.packages = {} self.subgraph = {} self.subgraph_node = {} self.control_flow = deque() self.user_defined_class = [] self.working_file = {} self.target_node = None self.data_flow_container = {} self.library_path = {} # Save library path that can't be extrapolated. i.e. train_test_split def visit(self, node: AST) -> Any: if type(node) in [ast.Assign, ast.Import, ast.ImportFrom, ast.Expr, ast.For, ast.If, ast.FunctionDef, ast.AugAssign, ast.Call, ast.Return, ast.Attribute]: # TODO: Improve this self.columns.clear() self.graph_info.add_node(format_node_text(node)) if len(self.control_flow) > 0: self.graph_info.add_control_flow(self.control_flow) return super().visit(node) def visit_Interactive(self, node: Interactive) -> Any: pass def visit_Expression(self, node: Expression) -> Any: pass def visit_FunctionDef(self, node: FunctionDef) -> Any: subgraph = SubGraph() subgraph.graph_info = self.graph_info args = self.visit_arguments(node.args) subgraph.arguments = {i: a for a in args for i in range(len(args))} subgraph.working_file = self.working_file self.subgraph[node.name] = subgraph self.subgraph_node[node.name] = node def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any: pass def visit_ClassDef(self, node: ClassDef) -> Any: pass def visit_Return(self, node: Return) -> Any: pass def visit_Delete(self, node: Delete) -> Any: pass def visit_Assign(self, node: Assign) -> Any: assign_components = AssignComponents() value_package = get_ast_package(node.value) value_package.extract_assign_value(self, node, assign_components) variable = None # TODO: Choose better name for target in node.targets: target_package = get_ast_package(target) variable = target_package.analyze_assign_target(self, target, assign_components) if variable is not None and not isinstance(variable, list): self._connect_node_to_column(self.files.get(variable)) def visit_AugAssign(self, node: AugAssign) -> Any: pass def visit_AnnAssign(self, node: AnnAssign) -> Any: pass def visit_For(self, node: For) -> Any: if len(self.control_flow) == 0: self.graph_info.add_control_flow([ControlFlow.LOOP]) fsg = ForSubGraph() fsg.graph_info = self.graph_info fsg.working_file = self.working_file fsg.files = self.files.copy() fsg.variables = self.variables.copy() fsg.packages = self.packages fsg.alias = self.alias fsg.subgraph = self.subgraph fsg.subgraph_node = self.subgraph_node fsg.control_flow = self.control_flow.copy() fsg.control_flow.append(ControlFlow.LOOP) fsg.data_flow_container = self.data_flow_container.copy() fsg.visit(node) def visit_AsyncFor(self, node: AsyncFor) -> Any: pass def visit_While(self, node: While) -> Any: pass def visit_If(self, node: If) -> Any: if len(self.control_flow) == 0: self.graph_info.add_control_flow([ControlFlow.CONDITIONAL]) self.control_flow.append(ControlFlow.CONDITIONAL) for el in node.body: self.visit(el) for el in node.orelse: self.visit(el) self.control_flow.pop() def visit_With(self, node: With) -> Any: pass def visit_AsyncWith(self, node: AsyncWith) -> Any: pass def visit_Raise(self, node: Raise) -> Any: pass def visit_Try(self, node: Try) -> Any: pass def visit_Assert(self, node: Assert) -> Any: pass def visit_Import(self, node: Import) -> Any: self.graph_info.add_control_flow([ControlFlow.IMPORT]) for import_name in node.names: name, as_name = self.visit_alias(import_name) if as_name is not None: self.alias[as_name] = name self.graph_info.add_import_node(name) def visit_ImportFrom(self, node: ImportFrom) -> Any: self.graph_info.add_control_flow([ControlFlow.IMPORT]) path = node.module for lib_name in node.names: name, as_name = self.visit_alias(lib_name) if as_name is not None: self.alias[as_name] = name self.graph_info.add_import_from_node(path, name) self.library_path[name] = f'{path}.{name}' def visit_Global(self, node: Global) -> Any: pass def visit_Nonlocal(self, node: Nonlocal) -> Any: pass def visit_Expr(self, node: Expr) -> Any: if isinstance(node.value, ast.Call): self.visit_Call(node.value) elif isinstance(node.value, ast.Constant): pass else: print("EXPR VALUE:", node.__dict__) print(astor.to_source(node)) def visit_Pass(self, node: Pass) -> Any: pass def visit_Break(self, node: Break) -> Any: pass def visit_Continue(self, node: Continue) -> Any: pass def visit_Slice(self, node: Slice) -> Any: lower = upper = step = None if isinstance(node.lower, ast.Constant): lower = self.visit_Constant(node.lower) return lower, upper, step def visit_BoolOp(self, node: BoolOp) -> Any: pass def visit_BinOp(self, node: BinOp) -> Any: components = BinOpComponents() left_ast_package = get_ast_package(node.left) left_ast_package.analyze_bin_op_branch(self, node, 'left', components) right_ast_package = get_ast_package(node.right) right_ast_package.analyze_bin_op_branch(self, node, 'right', components) left_package = get_package(components.left, self.alias) right_package = get_package(components.right, self.alias) if left_package is not None and right_package is not None: if left_package == pd_dataframe or right_package == pd_dataframe: return f'{pd_dataframe.library_path}.{pd_dataframe.name}' return astor.to_source(node) elif left_package is not None and len(left_package.return_types) > 0: pack = left_package.return_types[0] return f"{pack.library_path}.{pack.name}" elif right_package is not None and len(right_package.return_types) > 0: pack = right_package.return_types[0] return f"{pack.library_path}.{pack.name}" else: return format_node_text(node) def visit_UnaryOp(self, node: UnaryOp) -> Any: pass def visit_Lambda(self, node: Lambda) -> Any: pass def visit_IfExp(self, node: IfExp) -> Any: pass def visit_Dict(self, node: Dict) -> Any: keys = [] values = [] for key in node.keys: if isinstance(key, ast.Constant): keys.append(self.visit_Constant(key)) else: print("DICT KEY:", node) print(astor.to_source(node)) for value in node.values: if isinstance(value, ast.Constant): values.append(self.visit_Constant(value)) elif isinstance(value, ast.Name): values.append(self.visit_Name(value)) elif isinstance(value, ast.Attribute): attr_value, _, _ = self.visit_Attribute(value) values.append(attr_value) elif isinstance(value, ast.List): values.append(self.visit_List(value)) elif isinstance(value, ast.Call): _, _, _, base = self.visit_Call(value) values.append(base) else: print("DICT VALUE:", node.values) print(astor.to_source(node)) return {a: b for a, b in tuple(zip(keys, values))} def visit_Set(self, node: Set) -> Any: pass def visit_ListComp(self, node: ListComp) -> Any: pass def visit_SetComp(self, node: SetComp) -> Any: pass def visit_DictComp(self, node: DictComp) -> Any: pass def visit_GeneratorExp(self, node: GeneratorExp) -> Any: pass def visit_Await(self, node: Await) -> Any: pass def visit_Yield(self, node: Yield) -> Any: pass def visit_YieldFrom(self, node: YieldFrom) -> Any: pass def visit_Compare(self, node: Compare) -> Any: if isinstance(node.left, ast.Subscript): self.visit_Subscript(node.left) elif isinstance(node.left, ast.Call): self.visit_Call(node.left) else: print("COMPARE LEFT:", node.__dict__) print(astor.to_source(node)) for comparator in node.comparators: if isinstance(comparator, ast.Call): self.visit_Call(comparator) def _connect_node_to_column(self, file): if file is None: return self.graph_info.add_columns(file.filename, self.columns) self.columns.clear() def _extract_parent_package_information(self, components: CallComponents): if components.package is None: return if isinstance(components.package, list): return components.extract_parent_library() pkg = self.variables.get(components.parent_library) components.rewrite_library_path(pkg) self._connect_node_to_column(self.files.get(components.parent_library)) def visit_Call(self, node: Call) -> Any: func_package: AstPackage func_package = get_ast_package(node.func) call_components = CallComponents() func_package.extract_func(self, node, call_components) self._extract_parent_package_information(call_components) package_class = self._get_package_info(call_components.package) parameters = package_class.parameters.copy() args_components = CallArgumentsComponents(package_class.parameters.keys()) for i in range(len(node.args)): if not args_components.is_block: args_components.next_label() args_components.set_is_block(parameters) args_package = get_ast_package(node.args[i]) parameter_value = args_package.analyze_call_arguments(self, node, args_components, call_components, i) if package_class.is_relevant: insert_parameter(parameters, args_components.is_block, args_components.label, parameter_value) for kw in node.keywords: if isinstance(kw, ast.keyword): edge, value = self.visit_keyword(kw) self._extract_dataflow(value) self._add_to_column(value, call_components.base_package) if package_class.is_relevant and edge is not None: parameters[edge] = str(value) self.graph_info.add_parameters(parameters) self._create_package_call(package_class, call_components.package) if call_components.base_package is not None: variable, *_ = call_components.base_package.split('.') self._connect_node_to_column(self.files.get(variable)) if type(call_components.package) != list and call_components.package in self.subgraph.keys(): return_type = self._subgraph_logic(call_components.package, args_components.file_args, args_components.call_args) return return_type, call_components.file, call_components.package, call_components.base_package elif package_class.is_relevant: return (package_class.return_types, call_components.file, call_components.package, call_components.base_package) return [], call_components.file, call_components.package, call_components.base_package def visit_FormattedValue(self, node: FormattedValue) -> Any: pass def visit_JoinedStr(self, node: JoinedStr) -> Any: pass def visit_Constant(self, node: Constant) -> Any: return node.value def visit_NamedExpr(self, node: NamedExpr) -> Any: pass def visit_Attribute(self, node: Attribute) -> Any: components = AttributeComponents() attribute_package = get_ast_package(node.value) attribute_package.analyze_attribute_value(self, node, components) return components.path, components.parent_path, components.file def visit_Subscript(self, node: Subscript) -> Any: value_package = get_ast_package(node.value) name, base = value_package.extract_subscript_value(self, node) if not isinstance(name, list): var = self.variables.get(name, self.variables.get(base)) file = self.files.get(name, self.files.get(base)) if var is pd_dataframe and file is not None: working_file = self.working_file.get(file.filename, pd.DataFrame()) if isinstance(node.slice, ast.Index): index = self.visit_Index(node.slice) column_list = list(working_file) if not isinstance(index, list): index = self.variables.get(index, index) if isinstance(index, int): if index < len(column_list): column = column_list[index] self.columns.append(column) elif isinstance(index, str) and index in column_list: self.columns.append(index) elif isinstance(index, list): for value in index: if value in column_list: self.columns.append(value) elif isinstance(node.slice, ast.ExtSlice): self.visit_ExtSlice(node.slice) # TODO: HOW TO EXPRESS THIS VALUE? elif isinstance(node.slice, ast.Slice): lower, upper, step = self.visit_Slice(node.slice) columns = list(working_file) if lower is None: lower = 0 if upper is None: upper = len(columns) if step is None: step = 1 if isinstance(lower, int): for i in range(lower, upper, step): self.columns.append(columns[i]) elif isinstance(lower, str): is_checked = False for col in columns: if is_checked: self.columns.append(col) if col == upper: break else: if col == lower: self.columns.append(col) is_checked = True else: print('SUBSCRIPT SOMETHING', node.slice, node.__dict__) print(astor.to_source(node)) self._connect_node_to_column(file) return name, base def visit_Starred(self, node: Starred) -> Any: pass def visit_Name(self, node: Name) -> Any: return node.id def visit_List(self, node: List) -> Any: elements = [] for i in range(len(node.elts)): element_package = get_ast_package(node.elts[i]) element_package.extract_list_element(self, node, i, elements) return elements def visit_Tuple(self, node: Tuple) -> Any: elements = [] for element in node.elts: if isinstance(element, ast.Name): elements.append(self.visit_Name(element)) elif isinstance(element, ast.Constant): elements.append(self.visit_Constant(element)) elif isinstance(element, ast.Subscript): name, _ = self.visit_Subscript(element) elements.append(name) elif isinstance(element, ast.Call): _, _, _, base = self.visit_Call(element) elements.append(base) else: print("TUPLE ELTS:", node.__dict__) print(astor.to_source(node)) return elements def visit_Del(self, node: Del) -> Any: pass def visit_Load(self, node: Load) -> Any: pass def visit_Store(self, node: Store) -> Any: pass def visit_And(self, node: And) -> Any: pass def visit_Or(self, node: Or) -> Any: pass def visit_Add(self, node: Add) -> Any: pass def visit_BitAnd(self, node: BitAnd) -> Any: pass def visit_BitOr(self, node: BitOr) -> Any: pass def visit_BitXor(self, node: BitXor) -> Any: pass def visit_Div(self, node: Div) -> Any: pass def visit_FloorDiv(self, node: FloorDiv) -> Any: pass def visit_LShift(self, node: LShift) -> Any: pass def visit_Mod(self, node: Mod) -> Any: pass def visit_Mult(self, node: Mult) -> Any: pass def visit_MatMult(self, node: MatMult) -> Any: pass def visit_Pow(self, node: Pow) -> Any: pass def visit_RShift(self, node: RShift) -> Any: pass def visit_Sub(self, node: Sub) -> Any: pass def visit_Invert(self, node: Invert) -> Any: pass def visit_Not(self, node: Not) -> Any: pass def visit_UAdd(self, node: UAdd) -> Any: pass def visit_USub(self, node: USub) -> Any: pass def visit_Eq(self, node: Eq) -> Any: pass def visit_Gt(self, node: Gt) -> Any: pass def visit_GtE(self, node: GtE) -> Any: pass def visit_In(self, node: In) -> Any: pass def visit_Is(self, node: Is) -> Any: pass def visit_IsNot(self, node: IsNot) -> Any: pass def visit_Lt(self, node: Lt) -> Any: pass def visit_LtE(self, node: LtE) -> Any: pass def visit_NotEq(self, node: NotEq) -> Any: pass def visit_NotIn(self, node: NotIn) -> Any: pass def visit_comprehension(self, node: comprehension) -> Any: pass def visit_ExceptHandler(self, node: ExceptHandler) -> Any: pass def visit_arguments(self, node: arguments) -> Any: arguments_list = [] for argument in node.args: arguments_list.append(self.visit_arg(argument)) return arguments_list def visit_arg(self, node: arg) -> Any: return node.arg def visit_keyword(self, node: keyword) -> Any: keyword_package = get_ast_package(node.value) value = keyword_package.extract_keyword_value(self, node) return node.arg, value def visit_alias(self, node: alias) -> Any: return node.name, node.asname def visit_withitem(self, node: withitem) -> Any: pass def visit_ExtSlice(self, node: ast.ExtSlice) -> Any: for dim in node.dims: sl = None if isinstance(dim, ast.Slice): lower, upper, step = self.visit_Slice(dim) elif isinstance(dim, ast.Index): sl = self.visit_Index(dim) else: print("EXT_SLICE DIM:", node.__dict__) print(astor.to_source(node)) pass def visit_Index(self, node: ast.Index) -> Any: if isinstance(node.value, ast.Constant): return self.visit_Constant(node.value) elif isinstance(node.value, ast.BinOp): self.visit_BinOp(node.value) elif isinstance(node.value, ast.Compare): self.visit_Compare(node.value) elif isinstance(node.value, ast.List): return self.visit_List(node.value) elif isinstance(node.value, ast.Subscript): print(self.visit_Subscript(node.value)) # TODO: MAKE SOMETHING OF THIS VALUE elif isinstance(node.value, ast.Name): return self.visit_Name(node.value) else: print("INDEX VALUE:", node.__dict__) print(astor.to_source(node)) def visit_Suite(self, node: ast.Suite) -> Any: pass def visit_AugLoad(self, node: ast.AugLoad) -> Any: pass def visit_AugStore(self, node: ast.AugStore) -> Any: pass def visit_Param(self, node: ast.Param) -> Any: pass def visit_Num(self, node: ast.Num) -> Any: pass def visit_Str(self, node: ast.Str) -> Any: pass def visit_Bytes(self, node: ast.Bytes) -> Any: pass def visit_NameConstant(self, node: ast.NameConstant) -> Any: pass def visit_Ellipsis(self, node: Ellipsis) -> Any: pass def _add_to_column(self, column_name, table_name=None): file = self.files.get(table_name, File('')) file_name = file.filename working_file = self.working_file.get(file_name, pd.DataFrame()) if column_name in list(working_file): self.columns.append(column_name) return True return False def _subgraph_logic(self, package, file_args, call_args): s_graph = self.subgraph.get(package) s_graph.data_flow_container = self.data_flow_container.copy() for param in s_graph.arguments.values(): s_graph.data_flow_container[param] = self.graph_info.tail s_graph.files = {s_graph.arguments.get(file_k): self.files.get(file_args.get(file_k)) for file_k in file_args.keys()} s_graph.variables = {s_graph.arguments.get(arg_k): self.variables.get(call_args.get(arg_k)) for arg_k in call_args.keys()} s_graph.is_starting = True s_graph.visit(self.subgraph_node.get(package)) return s_graph.return_type def param_subgraph_init(self): psg = ParamSubGraph() psg.graph_info = self.graph_info psg.working_file = self.working_file psg.files = self.files psg.variables = self.variables psg.packages = self.packages psg.alias = self.alias psg.data_flow_container = self.data_flow_container psg.library_path = self.library_path return psg def _create_package_call(self, package_class, package): if package_class.is_relevant: self.graph_info.add_package_call(package_class.full_path()) elif isinstance(package, str) and package in ("len", "range", "list"): self.graph_info.add_built_in_call(package) def _file_creation(self, file: str) -> File: if is_file(file): _, filename = os.path.split(file) file = File(filename) self.graph_info.add_file(file) return file def _extract_dataflow(self, variable_name): if not isinstance(variable_name, str): return flow = self.data_flow_container.get(variable_name) if flow is not None: self.graph_info.add_data_flows(flow) def _create_library_path(self, package_name): if not package_name or isinstance(package_name, list): return '' if '.' in package_name: base, *rest = package_name.split('.') base = self.alias.get(base, base) return f'{base}.{".".join(rest)}' return self.library_path.get(package_name, package_name) def _get_package_info(self, package_name): library_path = self._create_library_path(package_name) return packages.get(library_path, Calls.Call(is_relevant=False)) class SubGraph(NodeVisitor): __slots__ = ['arguments', 'is_starting', 'return_type'] def __init__(self): super().__init__() self.is_starting = True self.return_type = None def visit(self, node: AST) -> Any: if self.is_starting: self.is_starting = False return self.visit_FunctionDef(cast(ast.FunctionDef, node)) return super().visit(node) def visit_FunctionDef(self, node: FunctionDef) -> Any: self.control_flow.append(ControlFlow.METHOD) for el in node.body: self.visit(el) self.control_flow.pop() def visit_Return(self, node: Return) -> Any: if isinstance(node.value, ast.Call): self.visit_Call(node.value) elif isinstance(node.value, ast.Name): name = self.visit_Name(node.value) var = self.variables.get(name) if var is not None and isinstance(var, Calls.Call): self.return_type = var.return_types else: print("SUBGRAPH RETURN VALUE:", node.__dict__) print(astor.to_source(node)) class ParamSubGraph(NodeVisitor): __slots__ = ['return_type', 'target_node'] def __init__(self): super().__init__() self.return_type = None def visit(self, node: AST) -> Any: n = super().visit(node) if self.target_node is None: self.graph_info.rewrite_node_flow() else: self.graph_info.insert_before(self.target_node) return n def visit_Call(self, node: Call) -> Any: # TODO: REMOVE UNUSED ELEMENT return_types, file, name, base = super().visit_Call(node) info = None # if package is not None: # *path, lib = package.split('.') # info = get_package(lib, '.'.join(path)) if info is not None: self.return_type = return_types[0] if len(return_types) > 0 else None self.return_type = return_types return return_types, file, name, base class ForSubGraph(NodeVisitor): def __init__(self): super().__init__() self.is_starting = True def visit(self, node: AST) -> Any: if self.is_starting: return self.visit_For(cast(ast.For, node)) return super().visit(node) def visit_For(self, node: For) -> Any: if self.is_starting: self.is_starting = False target = iter_value = None if isinstance(node.target, ast.Name): target = self.visit_Name(node.target) else: print("FOR TARGET", node.__dict__) print(astor.to_source(node)) if isinstance(node.iter, ast.Call): self.visit_Call(node.iter) elif isinstance(node.iter, ast.Name): iter_value = self.visit_Name(node.iter) elif isinstance(node.iter, ast.List): self.variables[target] = self.visit_List(node.iter) else: print("FOR ITER:", node.__dict__) print(astor.to_source(node)) if iter_value in self.variables.keys(): self.variables[target] = self.variables.get(iter_value) for el in node.body: self.visit(el) else: super().visit_For(node)