# ValueNet4SPARQL/src/intermediate_representation/sem2sparql/sem2SPARQL.py
import os
import traceback
import numbers
import random
import re
from intermediate_representation.graph import Graph
from intermediate_representation.sem2sparql.infer_from_clause import infer_from_clause
from intermediate_representation.semQL_SPARQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1, V
from intermediate_representation.sem_utils import alter_column0



def split_logical_form(lf):
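    """Split a serialized SemQL logical form at every closing parenthesis.

    Illustrative example (a sketch, assuming the usual one-argument actions):
    split_logical_form('Root1(3) Root(5) Sel(0)') -> ['Root1(3)', 'Root(5)', 'Sel(0)']
    """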
    indices = [i + 1 for i, letter in enumerate(lf) if letter == ')']
    indices.insert(0, 0)
    components = list()
    for i in range(1, len(indices)):
        components.append(lf[indices[i - 1]:indices[i]].strip())
    return components


def pop_front(array):
    if len(array) == 0:
        return 'None'
    return array.pop(0)


def peek_front(array):
    if len(array) == 0:
        return 'None'
    return array[0]


def is_end(components, transformed_sql, is_root_processed):
    end = False
    c = pop_front(components)
    c_instance = eval(c)

    if isinstance(c_instance, Root) and is_root_processed:
        # intersect, union, except
        end = True
    elif isinstance(c_instance, Filter):
        if 'where' not in transformed_sql:
            end = True
        else:
            num_conjunction = 0
            for f in transformed_sql['where']:
                if isinstance(f, str) and (f == 'and' or f == 'or'):
                    num_conjunction += 1
            current_filters = len(transformed_sql['where'])
            valid_filters = current_filters - num_conjunction
            if valid_filters >= num_conjunction + 1:
                end = True
    elif isinstance(c_instance, Order):
        if 'order' not in transformed_sql:
            end = True
        elif len(transformed_sql['order']) == 0:
            end = False
        else:
            end = True
    elif isinstance(c_instance, Sup):
        if 'sup' not in transformed_sql:
            end = True
        elif len(transformed_sql['sup']) == 0:
            end = False
        else:
            end = True
    components.insert(0, c)
    return end


def _transform(components, transformed_sql, col_set, table_names, values, schema):
    processed_root = False
    current_table = schema

    while len(components) > 0:
        if is_end(components, transformed_sql, processed_root):
            break
        c = pop_front(components)
        c_instance = eval(c)

        if isinstance(c_instance, Root):
            processed_root = True
            # a list with only 2 elements, similar to the Spider data structure:
            # [0] is the DISTINCT true/false flag, [1] holds the list of selected columns
            transformed_sql['select'] = [False, list()]
            if c_instance.id_c == 0:
                transformed_sql['where'] = list()
                transformed_sql['sup'] = list()
            elif c_instance.id_c == 1:
                transformed_sql['where'] = list()
                transformed_sql['order'] = list()
            elif c_instance.id_c == 2:
                transformed_sql['sup'] = list()
            elif c_instance.id_c == 3:
                transformed_sql['where'] = list()
            elif c_instance.id_c == 4:
                transformed_sql['order'] = list()
        elif isinstance(c_instance, Sel):
            if c_instance.id_c == 1:
                transformed_sql['select'][0] = True
        elif isinstance(c_instance, N):
            for i in range(c_instance.id_c + 1):
                agg = eval(pop_front(components))
                column = eval(pop_front(components))
                _table = pop_front(components)
                table = eval(_table)
                if not isinstance(table, T):
                    table = None
                    components.insert(0, _table)
                assert isinstance(agg, A) and isinstance(column, C)

                if table is not None:
                    col, _ = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
                else:
                    col = col_set[column.id_c]

                # if we have an aggregation (e.g. COUNT(xy)) and the query as a whole is DISTINCT, we also use DISTINCT
                # inside the aggregation (e.g. COUNT(DISTINCT xy)). While this is an oversimplification, it works for most
                # cases, as it is hard to formulate a question in a way that some things are distinct and some are not.
                use_distinct = False
                if agg.id_c != 0 and transformed_sql['select'][0]:
                    use_distinct = True

                transformed_sql['select'][1].append((
                    agg.production.split()[1],
                    col,
                    table_names[table.id_c] if table is not None else table,
                    use_distinct
                ))

        elif isinstance(c_instance, Sup):
            transformed_sql['sup'].append(c_instance.production.split()[1])
            agg = eval(pop_front(components))
            column = eval(pop_front(components))
            _table = pop_front(components)
            table = eval(_table)
            if not isinstance(table, T):
                table = None
                components.insert(0, _table)
            assert isinstance(agg, A) and isinstance(column, C)

            transformed_sql['sup'].append(agg.production.split()[1])
            if table:
                column_final, _ = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c],
                                                                current_table)
            else:
                column_final = col_set[column.id_c]
                raise RuntimeError('Table not found!')
            transformed_sql['sup'].append(column_final)
            transformed_sql['sup'].append(table_names[table.id_c] if table is not None else table)

        elif isinstance(c_instance, Order):
            transformed_sql['order'].append(c_instance.production.split()[1])
            agg = eval(pop_front(components))
            column = eval(pop_front(components))
            _table = pop_front(components)
            table = eval(_table)
            if not isinstance(table, T):
                table = None
                components.insert(0, _table)
            assert isinstance(agg, A) and isinstance(column, C)
            transformed_sql['order'].append(agg.production.split()[1])
            transformed_sql['order'].append(
                replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)[0])
            transformed_sql['order'].append(table_names[table.id_c] if table is not None else table)

        elif isinstance(c_instance, Filter):

            op = c_instance.production.split()[1]
            if op == 'and' or op == 'or':
                transformed_sql['where'].append(op)
            else:
                # No sub-query
                agg = eval(pop_front(components))
                column = eval(pop_front(components))
                _table = pop_front(components)
                table = eval(_table)
                if not isinstance(table, T):
                    table = None
                    components.insert(0, _table)
                assert isinstance(agg, A) and isinstance(column, C)

                # here we verify whether there is a sub-query
                component = peek_front(components)
                if isinstance(eval(component), V):
                    if table:
                        column_final, column_final_idx = replace_col_with_original_col(col_set[column.id_c],
                                                                                       table_names[table.id_c],
                                                                                       current_table)
                    else:
                        column_final = col_set[column.id_c]
                        raise RuntimeError('Table not found!')

                    second_value = None
                    # now we handle the values - we also handle data types properly.
                    value_obj = eval(pop_front(components))
                    value = format_value_given_datatype(column_final_idx, schema, values[value_obj.id_c], c_instance)

                    # there are a few special cases where we have to deal with multiple values - e.g. in the "X BETWEEN Y AND Z" case.
                    if isinstance(eval(peek_front(components)), V):
                        # e.g. GROUP BY ?t1_document_id HAVING(?agg >= 1 && ?agg <= 2)
                        second_value_obj = eval(pop_front(components))
                        second_value = format_value_given_datatype(column_final_idx, schema,
                                                                   values[second_value_obj.id_c], c_instance)

                    transformed_sql['where'].append((
                        op,
                        agg.production.split()[1],
                        column_final,
                        table_names[table.id_c] if table is not None else table,
                        value,
                        second_value,
                    ))


                else:
                    # sub-query
                    new_dict = dict()
                    new_dict['sql'] = transformed_sql['sql']
                    transformed_sql['where'].append((
                        op,
                        agg.production.split()[1],
                        replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)[0],
                        table_names[table.id_c] if table is not None else table,
                        _transform(components, new_dict, col_set, table_names, values, schema),
                        None
                    ))

    return transformed_sql
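
# Shape of the dict produced by _transform (a sketch inferred from the code above;
# which keys are present depends on the Root variant):
#   {'sql':    <the input query dict>,
#    'select': [distinct_flag, [(agg, column, table, use_distinct), ...]],
#    'where':  [(op, agg, column, table, value, second_value), 'and'/'or', ...],
#    'sup':    [direction, agg, column, table],
#    'order':  [direction, agg, column, table]}
# where `value` in a 'where' tuple may itself be such a dict for sub-queries.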


def transformSparql(query, schema, origin=None):
    preprocess_schema(schema)
    if origin is None:
        lf = query['model_result_replace']
    else:
        lf = origin
    # lf = query['rule_label']
    col_set = query['col_set']
    table_names = query['table_names']
    values = query['values']
    current_table = schema

    current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
    current_table['schema_content'] = [x[1] for x in current_table['column_names_original']]

    components = split_logical_form(lf)

    transformed_sql = dict()
    transformed_sql['sql'] = query
    c = pop_front(components)
    c_instance = eval(c)
    assert isinstance(c_instance, Root1)
    if c_instance.id_c == 0:
        transformed_sql['intersect'] = dict()
        transformed_sql['intersect']['sql'] = query
        _transform(components, transformed_sql, col_set, table_names, values, schema)
        _transform(components, transformed_sql['intersect'], col_set, table_names, values, schema)
    elif c_instance.id_c == 1:
        transformed_sql['union'] = dict()
        transformed_sql['union']['sql'] = query
        _transform(components, transformed_sql, col_set, table_names, values, schema)
        _transform(components, transformed_sql['union'], col_set, table_names, values, schema)
    elif c_instance.id_c == 2:
        transformed_sql['except'] = dict()
        transformed_sql['except']['sql'] = query
        _transform(components, transformed_sql, col_set, table_names, values, schema)
        _transform(components, transformed_sql['except'], col_set, table_names, values, schema)
    else:
        _transform(components, transformed_sql, col_set, table_names, values, schema)

    parse_result = to_str(transformed_sql, 1, schema)

    parse_result = parse_result.replace('\t', '')
    return [parse_result]


# table aliases (T1, T2, ...) are created here
def col_to_str(agg, col, tab, table_names, N=1, is_distinct=False):
    global lastUsedCol
    _col = col.replace(' ', '_')
    distinct_str = 'DISTINCT' if is_distinct else ''
    if agg == 'none':
        if tab not in table_names:
            table_names[tab] = 'T' + str(len(table_names) + N)
        table_alias = table_names[tab]
        if col == '*':
            return '*'
        lastUsedCol = '?%s_%s' % (table_alias, _col)
        return lastUsedCol
    else:
        if col == '*':
            if tab is not None and tab not in table_names:
                table_names[tab] = 'T' + str(len(table_names) + N)
            lastUsedCol = '?aggregation_all'
            return '(%s(%s %s) as ?%s_%s)' % (agg, distinct_str, _col, "aggregation", "all")
        else:
            if tab not in table_names:
                table_names[tab] = 'T' + str(len(table_names) + N)
            table_alias = table_names[tab]
            aggRandom = random.randint(0, 500)
            lastUsedCol = '?%s_%s_%s_%s' % ("aggregation", table_alias, _col, str(aggRandom))
            return '(%s(%s ?%s_%s) as ?%s_%s_%s_%s )' % (
            agg, distinct_str, table_alias, _col, "aggregation", table_alias, _col, str(
                aggRandom))
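
# A minimal sketch of the naming scheme above (assuming an empty alias map; note
# that aggregation variables over a concrete column get a random suffix):
#   col_to_str('none', 'name', 'singer', {}, 1)  -> '?T1_name'
#   col_to_str('count', '*', 'singer', {}, 1)    -> '(count( *) as ?aggregation_all)'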


def replace_col_with_original_col(query, col, current_table):
    # note: despite the parameter names, `query` is the cleaned column name and `col` is the table name
    if query == '*':
        return query, None

    cur_table = col
    cur_col = query
    single_final_col = None
    single_final_col_idx = None
    # print(query, col)
    for col_ind, col_name in enumerate(current_table['schema_content_clean']):
        if col_name == cur_col:
            assert cur_table in current_table['table_names']
            if current_table['table_names'][current_table['col_table'][col_ind]] == cur_table:
                single_final_col = current_table['column_names_original'][col_ind][1]
                single_final_col_idx = col_ind
                break
    # SELECT package_option ,  series_name FROM TV_Channel WHERE high_definition_TV  =  "yes"
    try:
        assert single_final_col
    except Exception as e:
        print(e)
    # if query != single_final_col:
    #     print(query, single_final_col)
    return single_final_col, single_final_col_idx
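
# Hedged example: for a schema whose 'schema_content_clean' lists 'song name' under
# table 'singer' (hypothetical names), replace_col_with_original_col('song name', 'singer', schema)
# would return the original column name (e.g. 'Song_Name') together with its column index.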


def build_graph(schema):
    relations = list()
    foreign_keys = schema['foreign_keys']
    for (fkey, pkey) in foreign_keys:
        fkey_table = schema['table_names_original'][schema['column_names'][fkey][0]]
        fkey_original_name = schema['column_names_original'][fkey][1]

        pkey_table = schema['table_names_original'][schema['column_names'][pkey][0]]
        pkey_original_name = schema['column_names_original'][pkey][1]

        relations.append((fkey_table, fkey_original_name, pkey_table, pkey_original_name))
        relations.append((pkey_table, pkey_original_name, fkey_table, fkey_original_name))

    return Graph(relations)
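
# Sketch of the relations built above: a foreign key pair linking, say, concert.stadium_id
# to stadium.stadium_id (hypothetical tables) yields both
# ('concert', 'stadium_id', 'stadium', 'stadium_id') and the reversed tuple, so the graph is undirected.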


# takes the info for one database from the schemas var
def preprocess_schema(schema):
    tmp_col = []
    for cc in [x[1] for x in schema['column_names']]:
        if cc not in tmp_col:
            tmp_col.append(cc)
    schema['col_set'] = tmp_col
    schema['schema_content'] = [col[1] for col in schema['column_names']]
    schema['col_table'] = [col[0] for col in schema['column_names']]
    graph = build_graph(schema)
    schema['graph'] = graph
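
# e.g. for column_names [(-1, '*'), (0, 'id'), (1, 'id')], 'col_set' becomes
# ['*', 'id'] (duplicates dropped) and 'col_table' becomes [-1, 0, 1].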


def alias(select_clause):
    # builds the outer-select aliases (?col1, ?col2, ...) for union/except queries
    aliases = []
    for idx, val in enumerate(select_clause):
        if '*' not in select_clause[0] and 'count' not in select_clause[0]:
            aliases.append('?col' + str(idx + 1))
        else:
            aliases = select_clause
    return ' '.join(aliases)
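
# For instance: alias(['?T1_name', '?T1_age']) -> '?col1 ?col2', while a select clause
# whose first entry contains '*' or 'count' is returned as-is (space-joined).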


def format_value_given_datatype(column_final_idx, schema, value, filter_instance):
    # there are some special cases in the training set where the user asks for an "empty" history or
    # similar. In that case, we really do need to add an empty string.
    if value == '':
        return "\'\'"

    # before we handle the values, we want to find out what data type the value has (based on the schema)
    if column_final_idx is not None:
        data_type = schema['column_types'][column_final_idx]
        if data_type == 'text':
            use_quotes = True
        elif data_type == 'time':
            use_quotes = True
        else:
            use_quotes = False
    else:
        # this means we are comparing the value with an aggregation - e.g. a COUNT(*). So it must be a number.
        use_quotes = False

    # sometimes a column is text, but the value in it is a number. We then have to remove the floating point, as
    # otherwise the comparison is wrong. Example: for a boolean column stored as VARCHAR(1), you have to compare 1
    # with '1' and not with '1.0'. For every other numeric comparison, the float/int difference is handled properly by SQL.
    if use_quotes and isinstance(value, numbers.Number):
        value = int(value)

    # filter 9 is 'LIKE' - the wildcards would normally be added to the value here,
    # but they are currently left out (see the commented-out line below).
    # TODO: introduce new filter actions for LIKE_FUZZY_BEGINNING and LIKE_FUZZY_ENDING
    if filter_instance.id_c == 9:
        # value = f'%{value}%'
        value = f'{value}'

    value_formatted = "'{}'".format(value) if use_quotes else value

    return value_formatted
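
# Hedged examples of the formatting above: for a 'text' column, 1.0 becomes "'1'"
# (the float is truncated to an int, then quoted); for a numeric column, 1.0 is
# returned unchanged; an empty string always yields '' (two literal quote characters).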


# N_T is the offset at which table alias numbering starts (sub-queries use a high offset to avoid alias clashes)
def to_str(sql_json, N_T, schema, pre_table_names=None):
    all_columns = list()
    select_clause = list()
    table_names = dict()
    conjunctions = list()
    filters = list()
    filter_not_exists = list()
    filter_subqueries = list()
    filter_not_in = list()
    where_clause = ''
    sup_clause = ''
    order_clause = ''
    direction_map = {"des": 'DESC', 'asc': 'ASC'}

    current_table = schema

    whole_select_distinct = 'DISTINCT ' if sql_json['select'][0] or 'intersect' in sql_json else ''

    for (agg, col, tab, distinct) in sql_json['select'][1]:
        all_columns.append((agg, col, tab))
        select_clause.append(col_to_str(agg, col, tab, table_names, N_T, distinct))
        #if 'sup' in sql_json:
            #if ('count' in sql_json['sup'] and '(count( *) as ?aggregation_all)' not in select_clause):
             #   select_clause.append('(count( *) as ?aggregation_all)')
    global nested_query_filter
    # `agg` still holds the aggregation of the last select column (it leaks out of the loop above)
    if agg in select_clause[0]:
        nested_query_filter = lastUsedCol
    else:
        nested_query_filter = select_clause[0]
    global table_names_for_intersect
    table_names_for_intersect = table_names
    global column_not_in
    # at this point the select clause is just the names of the projected columns.
    # for union and except, the subquery column names need to be aliased with the same names in order to work.
    if 'Root1(1)' in sql_json['sql']['model_result'] or 'Root1(2)' in sql_json['sql']['model_result']:
        # checking how to alias the columns: if there is more than one column in the select, all of them need to be aliased.
        aliased_select_clause = list()
        if select_clause[0] == '*':
            select_clause_str = 'SELECT ' + whole_select_distinct + ' '.join(select_clause).strip()
            select_clause_str += ' WHERE {'
        else:
            i = 1
            for column_name in select_clause:
                column_name = '(' + column_name + ' as ?col' + str(i) + ')'
                aliased_select_clause.append(column_name)
                i += 1
            select_clause_str = 'SELECT ' + whole_select_distinct + ' '.join(aliased_select_clause).strip()
            select_clause_str += ' WHERE {'
    # for Intersect
    elif 'Root1(0)' in sql_json['sql']['model_result'] and whole_select_distinct == '':
        select_clause_str = ''
    elif select_clause[0] == '*':
        # for cases like SELECT * FROM table: in SPARQL this would otherwise return the full ?s ?p ?o triples instead of just the literals
        select_clause_str = 'SELECT ?literal WHERE {'
    else:
        select_clause_str = 'SELECT ' + whole_select_distinct + ' '.join(
            select_clause).strip()  # removed comma for column projection
        select_clause_str += ' WHERE {'

    if 'sup' in sql_json:
        (direction, agg, col, tab,) = sql_json['sup']
        all_columns.append((agg, col, tab))
        sup_subject = col_to_str(agg, col, tab, table_names, N_T)
        # subject = col_to_str(agg, col, tab, table_names, N_T)
        if sup_subject == '(count( *) as ?aggregation_all)':
            sup_clause = ('ORDER BY %s (count(*)) LIMIT 1' % (direction_map[direction])).strip()
        elif '(count( *) as ?aggregation_all)' in select_clause and sup_subject not in select_clause:
            sup_clause = ('GROUP BY %s ORDER BY %s (%s) LIMIT 1' % (sup_subject, direction_map[direction], sup_subject)).strip()
        elif agg == 'sum':
            sup_subject = re.sub(r'as.*', ')', sup_subject)
            sup_clause = ('ORDER BY %s %s LIMIT 1' % (
                direction_map[direction], sup_subject)).strip()
        else:
            sup_clause = ('ORDER BY %s (%s) LIMIT 1' % (
            direction_map[direction], sup_subject)).strip()  # changed order of vars
    elif 'order' in sql_json:
        (direction, agg, col, tab,) = sql_json['order']
        all_columns.append((agg, col, tab))
        subject = col_to_str(agg, col, tab, table_names, N_T)
        if subject == '(count( *) as ?aggregation_all)':
            subject = 'count(*)'
            order_clause = ('ORDER BY %s (%s)' % (direction_map[direction], subject)).strip()  # changed order of vars
        else:
            order_clause = ('ORDER BY %s (%s)' % (direction_map[direction], subject)).strip()
    has_group_by = False
    have_clause = ''

    if 'where' in sql_json:
        for f in sql_json['where']:
            if isinstance(f, str):
                conjunctions.append(f)
            else:
                # unpack the where clause into its components; `value` may itself be a sub-query (e.g. with 'not_in')
                op, agg, col, tab, value, value2 = f
                all_columns.append((agg, col, tab))
                subject = col_to_str(agg, col, tab, table_names, N_T)

                # here we detect the difference between a simple value or a value which refers to a subquery
                if not isinstance(value, dict):
                    column_not_in = ''.join(select_clause)
                    values_combined = value
                    # sparql equivalent is contains and the syntax is slightly different.
                    if op == 'contains':
                        values_combined = "{}".format(value)
                        filters.append('%s (%s, %s)' % (op, subject.lower(), values_combined))
                    elif value2 is not None:  # right now the only case where two values are allowed is a BETWEEN X AND Y statement.
                        if subject.startswith('(count'):
                            subject_2 = subject[1:]
                            values_combined = "{} >= {} && {} <= {}".format(subject.lower(), value, subject_2.lower(),value2)
                        else:
                            values_combined = "{} >= {} && {} <= {}".format(subject.lower(), value, subject.lower(), value2)
                        filters.append(values_combined)
                    else:
                        filters.append('%s %s %s' % (subject.lower(), op, values_combined))  # removed ? here
                else:
                    column_not_in = subject.lower()
                    value['sql'] = sql_json['sql']

                    # This is kind of a style change: instead of "xy IN (SELECT z FROM a)" one can also rewrite the query as a simple JOIN (by adding the table).
                    # While this is critical for Exact Matching (where joins are checked), it is also necessary for Execution Accuracy. The reason is simply that an "EXISTS (SELECT ID FROM ...)" or an "IN (SELECT ID FROM ...)"
                    # behaves slightly differently from a JOIN, especially when it comes to shortcut behaviour and therefore also the natural ordering. Read more here: https://blog.jooq.org/2016/03/09/sql-join-or-exists-chances-are-youre-doing-it-wrong/

                    # as we chose to model simple joins (see model_simple_joins_as_filter.py) with a filter, we use the opportunity here to change the filter back to a normal join, as expected in the ground truth.
                    # Obviously this would be an issue if a query tried to literally use "IN (SELECT ID FROM XY)". There are a few examples where this is the case, e.g.
                    # SELECT avg(age) FROM Dogs WHERE dog_id IN ( SELECT dog_id FROM Treatments ) (look for this in the ground truth)

                    number_of_selects = len(value['select'][1])
                    # each select tuple is (aggregation or 'none', column name, table name, distinct flag)
                    first_select_aggregation = value['select'][1][0][0]

                    if op == 'in' and number_of_selects == 1 \
                            and first_select_aggregation == 'none' \
                            and 'where' not in value \
                            and 'order' not in value \
                            and 'sup' not in value:

                        if value['select'][1][0][2] not in table_names:
                            table_names[value['select'][1][0][2]] = 'T' + str(len(table_names) + N_T)

                        # This is necessary to avoid incorrect queries: if there is an "and/or" conjunction at the end of the filters, we need a further filter to keep the query valid.
                        # If we apply a join instead of an "IN ()" statement, though, we need to remove that conjunction.
                        if len(conjunctions) > 0:
                            conjunctions.pop()

                        filters.append(None)
                    else:
                        if op == 'not_in':
                            filter_not_exists.append(
                                'Filter not exists {{' + to_str(value, len(table_names) + N_T + 20, schema) + '}')
                            if len(filter_not_exists) == 1:
                                # when the outer query also contains a filter, i.e. its 'where' list is longer than 1
                                if isinstance(sql_json['where'][0], tuple) and 'where' not in sql_json['where'][0][4]:
                                    filter_not_in = 'FILTER ({} in ({}))'.format(subject.lower(),
                                                                                 lastUsedCol.lower()) + '}'
                                else:
                                    filter_not_in = 'FILTER ({} in ({}))'.format(subject.lower(),
                                                                                 column_not_in.lower()) + '}'
                            else:
                                filter_not_in = ''
                        else:
                            # for every other sub-query we recurse to transform it to a string. By using a high N_T value we avoid conflicts with table aliases.
                            filter_subqueries.append('{' + to_str(value, len(table_names) + N_T + 20, schema) + '}')
                            filters.append(
                                '%s %s %s' % (subject.lower(), op, nested_query_filter.lower()))  # removed ? here too

                if len(conjunctions):
                    filters.append(conjunctions.pop())

        aggs = ['count(', 'avg(', 'min(', 'max(', 'sum(']
        having_filters = list()
        idx = 0
        while idx < len(filters):
            _filter = filters[idx]
            if _filter is None:
                idx += 1
                continue
            for agg in aggs:
                if _filter.lower().startswith(agg, 1):
                    # because a count in SPARQL looks like '(count( *) as ?aggregation_all) > 1', startswith has to check from index 1 to detect the aggregation
                    if agg != 'count(':
                        _filter = re.sub(r'as(.*?)\)', '', _filter)
                    having_filters.append(_filter)
                    filters.pop(idx)
                    # print(filters)
                    if 0 < idx and (filters[idx - 1] in ['and', 'or']):
                        filters.pop(idx - 1)
                        # print(filters)
                    break
            else:
                idx += 1
        if len(having_filters) > 0:
            # the appended ')' compensates for the ')' that is removed together with the 'as ?aggregation_all)' alias below
            have_clause = 'HAVING' + ' '.join(having_filters).strip() + ')'
            have_clause = have_clause.replace('as ?aggregation_all)', '')
        if len(filters) > 0:
            # print(filters)
            filters = [_f for _f in filters if _f is not None]
            conjun_num = 0
            filter_num = 0
            for _f in filters:
                if _f in ['or', 'and']:
                    conjun_num += 1
                else:
                    filter_num += 1
            if conjun_num > 0 and filter_num != (conjun_num + 1):
                # assert 'and' in filters
                idx = 0
                while idx < len(filters):
                    if filters[idx] == 'and':
                        if idx - 1 == 0:
                            filters.pop(idx)
                            break
                        if filters[idx - 1] in ['and', 'or']:
                            filters.pop(idx)
                            break
                        if idx + 1 >= len(filters) - 1:
                            filters.pop(idx)
                            break
                        if filters[idx + 1] in ['and', 'or']:
                            filters.pop(idx)
                            break
                    idx += 1
            if len(filters) > 0:
                where_clause = 'FILTER(' + ' '.join(filters).strip() + ') . '
            else:
                where_clause = ''

        if len(having_filters) > 0:
            has_group_by = True

    for agg in ['count(', 'avg(', 'min(', 'max(', 'sum(']:
        if (len(sql_json['select'][1]) > 1 and agg in select_clause_str.lower()) \
                or agg in sup_clause or agg in order_clause or ('sup' in sql_json and agg in sup_subject):
            has_group_by = True
            break

    group_by_clause = ''
    if has_group_by:
        if len(table_names) == 1:
            # check for non-aggregated select columns
            is_agg_flag = False
            all_group_by_columns = []
            for (agg, col, tab, _) in sql_json['select'][1]:
                if agg == 'none':
                    group_by_column = col_to_str(agg, col, tab, table_names, N_T)
                    all_group_by_columns.append(group_by_column)
                    group_by_clause = 'GROUP BY ' + ' '.join(column for column in all_group_by_columns)
                else:
                    is_agg_flag = True
                # if i <= len(sql_json['select'][1]):
                #     if agg == 'none':
                #         group_by_column = col_to_str(agg, col, tab, table_names, N_T)
                #         all_group_by_columns.append(group_by_column)
                #         i +=1
                #     else:
                #         is_agg_flag = True
                #         i+=1
                # elif all_group_by_columns is not None:
                #     group_by_clause = 'GROUP BY ' + ' '.join(column for column in all_group_by_columns)
            if is_agg_flag is False and len(group_by_clause) > 5:
                group_by_clause = "GROUP BY"
                for (agg, col, tab, _) in sql_json['select'][1]:
                    group_by_clause = group_by_clause + ' ' + col_to_str(agg, col, tab, table_names, N_T)
                # < 5 because if it gets here, the GROUP BY is empty. -Kate
            if len(group_by_clause) < 5:
                if 'COUNT(*)' in select_clause_str:
                    # it doesn't seem like this case is ever actually used... -Kate
                    current_table = schema
                    for primary in current_table['primary_keys']:
                        if current_table['table_names'][current_table['col_table'][primary]] in table_names:
                            group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][primary],
                                                                       current_table['table_names'][
                                                                           current_table['col_table'][primary]],
                                                                       table_names, N_T)
        else:
            group_by_columns = list()
            i = 0
            for projection in sql_json['select'][1]:
                if projection[0] == 'none':
                    if select_clause[i].startswith('?'):
                        group_by_columns.append(select_clause[i])
                    else:
                        i += 1
                        group_by_columns.append(select_clause[i])
                i += 1

            group_by_clause = 'GROUP BY ' + ' '.join(column for column in group_by_columns)
                # for key, value in table_names.items():
                #     if key not in non_lists:
                #         non_lists.append(key)
                #
                # a = non_lists[0]
                # b = None
                # for non in non_lists:
                #     if a != non:
                #         b = non
                # if b:
                #     for pair in current_table['foreign_keys']:
                #         t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
                #         t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
                #         if t1 in [a, b] and t2 in [a, b]:
                #             if pre_table_names and t1 not in pre_table_names:
                #                 assert t2 in pre_table_names
                #                 t1 = t2
                #             group_by_clause = 'GROUP BY ' + col_to_str('none',
                #                                                        current_table['schema_content'][pair[0]],
                #                                                        t1,
                #                                                        table_names, N_T)
                #             fix_flag = True
                #             break
                #
                # if fix_flag is False:
                #     agg, col, tab, _ = sql_json['select'][1][0]
                #     group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)

            # else:
            #     # check if there is only one non agg
            #     non_agg, non_agg_count = None, 0
            #     non_lists = []
            #     for (agg, col, tab, _) in sql_json['select'][1]:
            #         if agg == 'none':
            #             non_agg = (agg, col, tab)
            #             non_lists.append(tab)
            #             non_agg_count += 1
            #
            #     non_lists = list(set(non_lists))
            #     # print(non_lists)
            #     if non_agg_count == 1:
            #         group_by_clause = 'GROUP BY ' + col_to_str(non_agg[0], non_agg[1], non_agg[2], table_names, N_T)
            #     elif non_agg:
            #         find_flag = False
            #         fix_flag = False
            #         find_primary = None
            #         if len(non_lists) <= 1:
            #             for key, value in table_names.items():
            #                 if key not in non_lists:
            #                     non_lists.append(key)
            #         if len(non_lists) > 1:
            #             a = non_lists[0]
            #             b = None
            #             for non in non_lists:
            #                 if a != non:
            #                     b = non
            #             if b:
            #                 for pair in current_table['foreign_keys']:
            #                     t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
            #                     t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
            #                     if t1 in [a, b] and t2 in [a, b]:
            #                         if pre_table_names and t1 not in pre_table_names:
            #                             assert t2 in pre_table_names
            #                             t1 = t2
            #                         group_by_clause = 'GROUP BY ' + col_to_str('none',
            #                                                                    current_table['schema_content'][pair[0]],
            #                                                                    t1,
            #                                                                    table_names, N_T)
            #                         fix_flag = True
            #                         break
            #         tab = non_agg[2]
            #         assert tab in current_table['table_names']
            #
            #         for primary in current_table['primary_keys']:
            #             if current_table['table_names'][current_table['col_table'][primary]] == tab:
            #                 find_flag = True
            #                 find_primary = (current_table['schema_content'][primary], tab)
            #         if fix_flag is False:
            #             if find_flag is False:
            #                 # rely on count *
            #                 foreign = []
            #                 for pair in current_table['foreign_keys']:
            #                     if current_table['table_names'][current_table['col_table'][pair[0]]] == tab:
            #                         foreign.append(pair[1])
            #                     if current_table['table_names'][current_table['col_table'][pair[1]]] == tab:
            #                         foreign.append(pair[0])
            #
            #                 for pair in foreign:
            #                     if current_table['table_names'][current_table['col_table'][pair]] in table_names:
            #                         group_by_clause = 'GROUP BY ' + col_to_str('none',
            #                                                                    current_table['schema_content'][pair],
            #                                                                    current_table['table_names'][
            #                                                                        current_table['col_table'][pair]],
            #                                                                    table_names, N_T)
            #                         find_flag = True
            #                         break
            #                 if find_flag is False:
            #                     for (agg, col, tab, _) in sql_json['select'][1]:
            #                         if 'id' in col.lower():
            #                             group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
            #                             break
            #                     if len(group_by_clause) > 5:
            #                         pass
            #                     else:
            #                         raise RuntimeError('fail to convert')
            #             else:
            #                 group_by_clause = 'GROUP BY ' + col_to_str('none', find_primary[0],
            #                                                            find_primary[1],
            #                                                            table_names, N_T)
    intersect_clause = ''
    if 'intersect' in sql_json:
        sql_json['intersect']['sql'] = sql_json['sql']
        intersect_clause = '{' + to_str(sql_json['intersect'], len(table_names) + 1, schema, table_names) +'}'
        filter_columns = []
        for (agg, col, tab, distinct) in sql_json['intersect']['select'][1]:
            filter_columns.append(col_to_str(agg, col, tab, table_names_for_intersect, len(table_names) + 1, distinct))

        for index in range(len(filter_columns)):
            where_clause = where_clause + 'FILTER ({} in ({})) . '.format(select_clause[index].lower(),
                                                                          filter_columns[index].lower())

    union_clause = ''
    if 'union' in sql_json:
        sql_json['union']['sql'] = sql_json['sql']
        # union_clause = 'UNION {' + to_str(sql_json['union'], len(table_names) + 1, schema, table_names) #to reset table alias count to 1 for union clauses
        union_clause = 'UNION {' + to_str(sql_json['union'], 1, schema, table_names)
    except_clause = ''
    if 'except' in sql_json:
        sql_json['except']['sql'] = sql_json['sql']
        except_clause = 'MINUS {' + to_str(sql_json['except'], len(table_names) + 1, schema, table_names)
    # print(current_table['table_names_original'])
    table_names_replace = {}
    for a, b in zip(current_table['table_names_original'], current_table['table_names']):
        table_names_replace[b] = a
    new_table_names = {}
    for key, value in table_names.items():
        if key is None:
            continue
        new_table_names[table_names_replace[key].lower()] = value
    from_clause = infer_from_clause(new_table_names, schema, all_columns).strip()
    # the count aggregation has to appear in the select clause and NOT after the ORDER BY and DESC.
    # below we check whether the union and except clauses are None or empty strings.

    # the first case is for intersect-type queries, i.e. Root1(0)

    if 'Root1(0)' in sql_json['sql']['model_result'] and whole_select_distinct == '':
        sql = ' '.join([select_clause_str.lower(), from_clause.lower(), intersect_clause, where_clause,
                        union_clause, except_clause, order_clause.lower(),
                        sup_clause.lower()])
    elif (union_clause is None or union_clause.strip() == "") and (
            except_clause is None or except_clause.strip() == ""):
        sql = ' '.join([select_clause_str.lower(), from_clause.lower(), where_clause, intersect_clause,
                        union_clause.lower(), except_clause, ' '.join(filter_not_exists), ''.join(filter_not_in),
                        ' '.join(filter_subqueries), "}", group_by_clause.lower(), order_clause.lower(),
                        sup_clause.lower(),
                        have_clause.lower()])
    # if there is a union clause or an except clause, there needs to be an additional outer select statement;
    # the elif above handles all of the normal queries, the else below handles union/except
    else:
        # column names in the select statements for except and union type statements need to be aliased
        aliases = alias(select_clause)
        sql = ' '.join(
            ["SELECT distinct" + aliases + " WHERE { {", select_clause_str.lower(), from_clause.lower(), where_clause,
             intersect_clause, "}}", union_clause, except_clause, "} }", group_by_clause.lower(), order_clause.lower(),
             sup_clause.lower(), have_clause.lower()])

    # possible improvement: put the curly bracket into the from clause, and reconsider the order of the statements in general
    sql = sql.replace(' and ', ' && ').replace(' or ', ' || ')
    return sql


def transform_semQL_to_sparql(schemas, sem_ql_prediction, output_dir):
    # schemas contains all of the database info, taken from the db json files
    # TODO: find out if this adds any benefit for the trained models. If we run it with the ground truth (so no prediction, just SQL -> SemQL -> SQL) it is even slightly better without it.
    # alter_not_in(sem_ql_prediction, schemas=schemas)
    # alter_inter(sem_ql_prediction)
    alter_column0(sem_ql_prediction)
    index = range(len(sem_ql_prediction))
    count = 0
    exception_count = 0
    with open(os.path.join(output_dir, 'output.txt'), 'w', encoding='utf8') as d, open(
            os.path.join(output_dir, 'ground_truth.txt'), 'w', encoding='utf8') as g:
        for i in index:
            try:
                result = transformSparql(sem_ql_prediction[i], schemas[sem_ql_prediction[i]['db_id']])
                d.write(result[0] + '\n')
                g.write("%s\t%s\t%s\n" % (
                sem_ql_prediction[i]['query'], sem_ql_prediction[i]["db_id"], sem_ql_prediction[i]["question"]))
                count += 1
            except Exception as e:
                # This origin seems to be the fallback query. Not sure how we came up with it; most probably it's just a dummy query to fill in a result for each example.
                result = transformSparql(sem_ql_prediction[i], schemas[sem_ql_prediction[i]['db_id']],
                                         origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)')
                exception_count += 1
                d.write(result[0] + '\n')
                g.write("%s\t%s\t%s\n" % (
                sem_ql_prediction[i]['query'], sem_ql_prediction[i]["db_id"], sem_ql_prediction[i]["question"]))
                count += 1
                print(e)
                print('Exception')
                print(traceback.format_exc())
                print(sem_ql_prediction[i]['question'])
                print(sem_ql_prediction[i]['query'])
                print(sem_ql_prediction[i]['db_id'])
                print('===\n\n')

    return count, exception_count
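

# A minimal usage sketch (hedged: the file names and loading code are assumptions,
# but the expected keys are taken from the code above). `schemas` maps db_id -> the
# schema dict (the tables.json entry), and `sem_ql_prediction` is a list of dicts
# carrying 'model_result', 'model_result_replace', 'col_set', 'table_names',
# 'values', 'db_id', 'query' and 'question'.
#
#   import json
#   schemas = {db['db_id']: db for db in json.load(open('tables.json'))}
#   predictions = json.load(open('predictions.json'))
#   ok, failed = transform_semQL_to_sparql(schemas, predictions, 'out/')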