# ValueNet4SPARQL/src/spider/example_builder.py
import copy

import numpy as np
from nltk import WordNetLemmatizer

from intermediate_representation import lf
from preprocessing.pre_process import lemmatize_list
from spider.example import Example

# Take care: these imports are necessary for the dynamic eval() calls further down.
from intermediate_representation.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1, V

wordnet_lemmatizer = WordNetLemmatizer()


def build_column_matches_array(column_matches):
    """
    Here we map column matches (so the number of times a column got "hit" in the question) a simple 4d array.
    There might be smarter approaches to think about (e.g. embeddings?)
    """
    column_matches_array = np.zeros((len(column_matches), 4))

    for idx, match in enumerate(column_matches):
        column_matches_array[idx][0] = match['partial_column_match']
        column_matches_array[idx][1] = 5 if match['full_column_match'] else 0
        column_matches_array[idx][2] = 5 if match['full_value_match'] else 0
        # this never actually happens; we might remove it!
        column_matches_array[idx][3] = match['partial_value_match']

    return column_matches_array
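
# Illustrative sketch (the shape of a 'column_matches' entry is inferred from the key
# accesses above; the numbers are made up):
#
#   matches = [
#       {'partial_column_match': 2, 'full_column_match': False,
#        'full_value_match': False, 'partial_value_match': 0},
#       {'partial_column_match': 0, 'full_column_match': True,
#        'full_value_match': False, 'partial_value_match': 0},
#   ]
#   build_column_matches_array(matches)
#   # --> array([[2., 0., 0., 0.],
#   #            [0., 5., 0., 0.]])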


def build_example(sql, all_schemas):
    # The schema contains detailed information about the database, especially its tables & columns, as well as PK/FK relations.
    schema = all_schemas[sql['db_id']]

    column_matches = build_column_matches_array(sql['column_matches'])

    column_names, column_set, column_table_dict, columns_per_table, table_names = build_table_column_mappings(sql, schema)

    actions = _instantiate_actions(column_table_dict, sql)

    example = Example(
        question_tokens=[[token] for token in sql['question_toks']],
        n_columns=len(column_set),
        column_tokens=column_set,
        sql=sql['query'],
        column_matches=column_matches,
        tables=table_names,
        n_tables=len(table_names),
        column_table_dict=column_table_dict,
        columns=column_names,
        columns_per_table=columns_per_table,
        semql_actions=actions,
        values=sql['values']
    )
    example.sql_json = copy.deepcopy(sql)

    return example
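
# For reference: the keys that build_example() reads (directly or via its helpers) from
# the pre-processed example dict are 'db_id', 'question_toks', 'col_set', 'query',
# 'values', 'column_matches' and, if present, 'rule_label' (plus 'question' for error
# reporting in is_valid()).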


def build_table_column_mappings(sql, schema):
    _, table_names = lemmatize_list(schema['table_names'])

    column_names = [col[1] for col in schema['column_names']]
    column_table_indices = [col[0] for col in schema['column_names']]

    _, column_set = lemmatize_list(sql['col_set'])
    _, columns = lemmatize_list(column_names)

    # This dict tells us, for each column, in which tables it appears: the key is the index of the
    # column, the values are the indices of the tables. Example: key 0 appears in all tables, as
    # it's the special column "*". Most other columns appear only once, but the ids (used as
    # primary/foreign keys) can appear multiple times.
    column_table_dict = _get_column_table_dict(column_names, column_table_indices, sql)

    # A simple list with one sublist per table, containing all the column tokens of that table.
    columns_per_table = _get_column_sublists_per_table(column_table_indices, columns)

    # This entry is the special column "*", which refers to all columns in all tables. Not sure yet
    # why we replace it with these tokens (presumably so question words like "count" or "how many"
    # can match it).
    column_set[0] = ['count', 'number', 'many']

    return column_names, column_set, column_table_dict, columns_per_table, table_names
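
# Illustrative sketch of the expected input shapes (names made up, following the Spider
# tables.json layout, where 'column_names' pairs a table index with a column name and
# -1 marks the special column "*"):
#
#   schema = {
#       'table_names': ['singer', 'concert'],
#       'column_names': [[-1, '*'], [0, 'id'], [0, 'name'], [1, 'id'], [1, 'year']],
#   }
#   sql = {'col_set': ['*', 'id', 'name', 'year'], ...}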


def _instantiate_actions(column_table_dict, sql):
    """
    in the pre-processing (see sql2SemQL.py) we parse the sql for each example to the SemQL-AST language. We then serialize it to a string.
    Here we do the opposite: we deserialize the SemQL-Query by dynamically create the right objects based on the string.
    """

    rule_label = None
    if 'rule_label' in sql:
        # Example: eval("Root1(3)") will dynamically create an instance of class Root1 with the constructor argument 3.
        rule_label = [eval(x) for x in sql['rule_label'].strip().split(' ')]

        if not is_valid(rule_label, col_table_dict=column_table_dict, sql=sql):
            raise RuntimeError("Invalid rule_label: {}. We don't use this sample".format(sql['rule_label']))
    return rule_label
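
# Hypothetical example of the deserialization (the action sequence below is made up for
# illustration, not taken from the dataset):
#
#   sql['rule_label'] = 'Root1(3) Root(5) Sel(0) N(0) A(0) C(2) T(1)'
#   # --> [Root1(3), Root(5), Sel(0), N(0), A(0), C(2), T(1)], one SemQL action object
#   #     per space-separated token.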


def _get_column_table_dict(tab_cols, tab_ids, sql):
    table_dict = {}
    for c_id, c_v in enumerate(sql['col_set']):
        for cor_id, cor_val in enumerate(tab_cols):
            if c_v == cor_val:
                table_dict[tab_ids[cor_id]] = table_dict.get(tab_ids[cor_id], []) + [c_id]

    col_table_dict = {}
    for key_item, value_item in table_dict.items():
        for value in value_item:
            col_table_dict[value] = col_table_dict.get(value, []) + [key_item]
    # The special column "*" (index 0) belongs to every real table. The "- 1" compensates for the
    # pseudo-table index -1 that "*" itself contributed to table_dict.
    col_table_dict[0] = [x for x in range(len(table_dict) - 1)]
    return col_table_dict
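
# Worked example (toy schema with two tables; -1 is the pseudo-table of "*"):
#
#   tab_cols = ['*', 'id', 'name', 'id', 'age']
#   tab_ids  = [-1, 0, 0, 1, 1]
#   sql      = {'col_set': ['*', 'id', 'name', 'age']}
#   _get_column_table_dict(tab_cols, tab_ids, sql)
#   # --> {0: [0, 1], 1: [0, 1], 2: [0], 3: [1]}
#   # i.e. "*" and 'id' map to both tables, 'name' only to table 0, 'age' only to table 1.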


# Groups the (lemmatized) column tokens by table index, skipping the special column "*"
# (table index -1). Not really sure why this function is necessary, though.
def _get_column_sublists_per_table(tab_ids, tab_cols):
    table_col_dict = {}
    for ci, cv in zip(tab_ids, tab_cols):
        if ci != -1:
            table_col_dict[ci] = table_col_dict.get(ci, []) + cv
    result = []
    for ci in range(len(table_col_dict)):
        result.append(table_col_dict[ci])
    return result
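
# Worked example, continuing the toy schema above (tab_cols here are the lemmatized
# token lists, which is why "+ cv" concatenates tokens rather than appending lists):
#
#   tab_ids  = [-1, 0, 0, 1, 1]
#   tab_cols = [['*'], ['id'], ['name'], ['id'], ['age']]
#   _get_column_sublists_per_table(tab_ids, tab_cols)
#   # --> [['id', 'name'], ['id', 'age']]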


def is_valid(rule_label, col_table_dict, sql):
    try:
        lf.build_tree(copy.copy(rule_label))  # here we rebuild a tree from the list of rules
    except Exception:
        print(rule_label)

    flag = False
    for r_id, rule in enumerate(rule_label):
        if isinstance(rule, C):
            try:
                # the T action following a C action must point to a table that actually contains this column
                assert rule_label[r_id + 1].id_c in col_table_dict[rule.id_c]
            except (AssertionError, IndexError, KeyError):
                flag = True
                print(sql['question'])
    return not flag