# ValueNet4SPARQL / src / named_entity_recognition / pre_process_ner_values.py
from pytictoc import TicToc

from nltk import ngrams

from named_entity_recognition.database_value_finder.database_value_finder_sqlite import DatabaseValueFinderSQLite
from named_entity_recognition.handcrafted_heuristics import find_values_in_quote, find_ordinals, \
    find_emails, find_genders, find_null_empty_values, find_variety_of_common_mentionings, find_special_codes, \
    find_single_letters, find_capitalized_words, find_months, find_location_abbreviations

from named_entity_recognition.ner_extraction_data_dto import NerExtractionData

# Module-level cache of value finder instances, keyed by the database identifier
# that is passed to _get_or_create_value_finder().
all_database_value_finder = {}


def _get_or_create_value_finder(database, database_folder, db_schema):
    """
    Return the cached value finder for the given database, creating and caching it on first use.

    :param database: identifier of the database; used as the cache key.
    :param database_folder: forwarded to DatabaseValueFinderSQLite when a new finder is created.
    :param db_schema: forwarded to DatabaseValueFinderSQLite when a new finder is created.
    :return: the value finder stored in the module-level cache for this database.
    """
    if database in all_database_value_finder:
        # already created earlier - reuse it.
        return all_database_value_finder[database]

    new_finder = DatabaseValueFinderSQLite(database_folder, database, db_schema)
    all_database_value_finder[database] = new_finder
    return all_database_value_finder[database]


def pre_process_ner_candidates(ner_extracted_values, question, question_tokens):
    """
    Build a NerExtractionData object from the handcrafted heuristics and the NER entities.

    :param ner_extracted_values: iterable of NER entities; each entity is indexed by 'type' and, depending on the
        type, by 'name' and 'metadata'.
    :param question: the question, passed to the heuristics which work on the raw text.
    :param question_tokens: the tokens of the question, passed to the heuristics which work on tokens.
    :return: the NerExtractionData filled with all heuristic values and the pre-processed NER values.
    """
    extracted_data = NerExtractionData([], [], [], [], [], [], [], [], [], [], [], [], [], [], [])

    # pairs of (target list, values found by the corresponding handcrafted heuristic).
    heuristic_results = (
        (extracted_data.heuristic_values_in_quote, find_values_in_quote(question)),
        (extracted_data.heuristic_ordinals, find_ordinals(question_tokens)),
        (extracted_data.heuristics_emails, find_emails(question)),
        (extracted_data.heuristics_genders, find_genders(question_tokens)),
        (extracted_data.heuristics_null_empty, find_null_empty_values(question_tokens)),
        (extracted_data.heuristics_variety_common_mentionings, find_variety_of_common_mentionings(question_tokens)),
        (extracted_data.heuristics_special_codes, find_special_codes(question)),
        (extracted_data.heuristics_single_letters, find_single_letters(question)),
        (extracted_data.heuristics_capitalized_words, find_capitalized_words(question)),
        (extracted_data.heuristics_months, find_months(question_tokens)),
        (extracted_data.heuristics_location_abbreviations, find_location_abbreviations(question_tokens, question)),
    )
    for target_list, found_values in heuristic_results:
        target_list.extend(found_values)

    for entity in ner_extracted_values:
        # for all types see https://cloud.google.com/natural-language/docs/reference/rest/v1beta2/Entity#Type
        # TODO: extend this pre-processing for e.g. ADDRESSES, PHONE_NUMBERS - see the link above.
        entity_type = entity['type']

        if entity_type == 'NUMBER':
            extracted_data.ner_numbers.append(_compose_number(entity))
            continue

        if entity_type == 'DATE':
            # a date yields more than one candidate, therefore extend instead of append.
            extracted_data.ner_dates.extend(_compose_date(entity))
            continue

        if entity_type == 'PRICE':
            extracted_data.ner_prices.append(_compose_price(entity))
            continue

        entity_name = entity['name']
        if ' ' in entity_name:
            # there are multiple words in this value - create combinations out of it.
            extracted_data.ner_remaining.extend(_build_ngrams(entity_name))
        else:
            # a single word - take the extracted value without any adaptions.
            extracted_data.ner_remaining.append(entity_name)

    return extracted_data


def match_values_in_database(db_value_finder, extracted_data, include_primary_keys):
    """
    Turn the extracted values into (value, tolerance) candidates and look them up in the database.

    Note that heuristics_null_empty and heuristics_months of the extracted data are not used as candidates here.

    :param db_value_finder: object providing the three similarity thresholds, the attribute 'database' and the
        method find_similar_values_in_database().
    :param extracted_data: the pre-processed values, see pre_process_ner_candidates().
    :param include_primary_keys: forwarded unchanged to find_similar_values_in_database().
    :return: whatever find_similar_values_in_database() returns for the candidates.
    """
    # NOTE: adapting these thresholds should always be done empirically and heavily depends on the chosen similarity metric.
    # Have a look at the script find_optimal_similarity_threshold.py to find optimal thresholds.
    exact = db_value_finder.exact_match_threshold
    high = db_value_finder.high_similarity_threshold
    medium = db_value_finder.medium_similarity_threshold

    # Depending on the candidate type we set a different tolerance value for similarity matching with db-values.
    # Remember: 1.0 is looking for exact matches only. Also remember: we do lower-case only comparison,
    # so 'Male' and 'male' will match with 1.0.
    # The order of the groups matters, as it defines the order of the resulting candidates.
    candidate_groups = (
        # With values in quote we are a bit tolerant. Important: we keep these values anyway,
        # as they are often used in fuzzy LIKE searches.
        (extracted_data.heuristic_values_in_quote, high),
        # For gender values we only want exact matches.
        (extracted_data.heuristics_genders, exact),
        (extracted_data.heuristics_variety_common_mentionings, high),
        # a special code should match exactly.
        (extracted_data.heuristics_special_codes, exact),
        (extracted_data.heuristics_capitalized_words, medium),
        (extracted_data.heuristics_location_abbreviations, high),
        # important: in addition to all the handcrafted features, also take all values from the NER
        # which aren't known dates/numbers/prices.
        (extracted_data.ner_remaining, medium),
        (extracted_data.heuristic_ordinals, exact),
        (extracted_data.heuristics_emails, high),
        (extracted_data.heuristics_single_letters, exact),
        (extracted_data.ner_dates, exact),
        (extracted_data.ner_numbers, exact),
        (extracted_data.ner_prices, exact),
    )

    candidates = []
    for group_values, group_tolerance in candidate_groups:
        _add_without_duplicates([(group_value, group_tolerance) for group_value in group_values], candidates)

    # measure how long the database lookup takes.
    timer = TicToc()
    timer.tic()
    print(f'Look for potential candidates "{candidates}" in database {db_value_finder.database} (include primary keys: {include_primary_keys})')
    matching_db_values = db_value_finder.find_similar_values_in_database(candidates, include_primary_keys)
    print(f'Confirmed the following candidates "{matching_db_values}"')
    timer.toc()

    return matching_db_values


def _add_without_duplicates(new_candidates, candidates):
    for value, tolerance in new_candidates:
        existing_candidate = next(filter(lambda value_tolerance: value_tolerance[0] == value, candidates), None)
        if existing_candidate:
            existing_value, existing_tolerance = existing_candidate
            if existing_tolerance < tolerance:
                candidates.remove((existing_value, existing_tolerance))
                candidates.append((value, tolerance))
        else:
            candidates.append((value, tolerance))


def _compose_number(entity):
    # NUMBER will also detect e.g. a "one" and transform it to a 1 in the metadata
    value_as_string = entity['metadata']['value']

    if '.' in value_as_string:
        # Some floats from NER use trailing zeros - strip them away.
        return value_as_string.rstrip('0').rstrip('.')
    else:
        return value_as_string


def _compose_price(entity):
    # PRICE contains also the "currency" in the metadata. We assume that we don't need it.
    return entity['metadata']['value']


def _compose_date(entity):
    """
    This method formats returns a proper 'YYYY-MM-DD' string or a subpart of it (e.g. 'YYYY-MM') if not all information available.
    See Tests for more information.
    """
    full_date = ''
    if 'year' in entity['metadata']:
        full_date = entity['metadata']['year']

    if 'month' in entity['metadata']:
        if len(entity['metadata']['month']) == 1:
            month = '0' + entity['metadata']['month']
        else:
            month = entity['metadata']['month']

        if full_date:
            full_date = full_date + '-' + month
        else:
            full_date = month

    if 'day' in entity['metadata']:

        if len(entity['metadata']['day']) == 1:
            day = '0' + entity['metadata']['day']
        else:
            day = entity['metadata']['day']

        if full_date:
            full_date = full_date + '-' + day
        else:
            full_date = day

    # TODO: there is 4 cases where the database is a string instead of a date (e.g. "voter_2", "Voting_record" -->08/30/2015). Therefore we also deliver the value here. Fix the database.
    return [full_date, entity['name']]


def _build_ngrams(multi_token_input):
    """
    Create combinations out of a value which consists of multiple tokens.

    :param multi_token_input: the value as string, e.g. 'John Doe Smith'.
    :return: a list starting with the unchanged input, followed by all n-grams (joined by a space) of
        size 1 up to (number of tokens - 1).
    """
    # this is a rather simple split - might consider spaCy
    tokens = multi_token_input.split()

    combinations = [multi_token_input]
    for size in range(1, len(tokens)):
        # n-gram tuples can e.g. be ('John', 'Doe') and ('Doe', 'Smith')
        combinations.extend(' '.join(gram) for gram in ngrams(tokens, size))

    return combinations


def add_non_found_values(expected_values, candidates_original):
    """
    Add every expected value which is not part of the candidates to a copy of the candidates.

    :param expected_values: the values we expect to find (ground truth).
    :param candidates_original: the extracted candidates; not modified, we work on a copy.
    :return: a tuple of (candidates including the added values, True if nothing had to be added,
        number of expected values).
    """
    candidates = candidates_original.copy()
    all_found = True

    for value in expected_values:
        if any(_is_value_equal(candidate, value) for candidate in candidates):
            continue

        # value is missing in the candidates - remember that and add it.
        all_found = False
        print(f"Could not find '{value}' in extracted values '{candidates}'. We add it from the ground truth.")
        candidates.append(value)

    return candidates, all_found, len(expected_values)


def _is_value_equal(extracted_value, expected_value):
    # there are some cases were we have a float stored in the ground truth - even though we are actually looking for an int.
    if isinstance(expected_value, float) and expected_value.is_integer():
        expected_value = int(expected_value)

    expected_value = str(expected_value)

    return expected_value == extracted_value


def compare_questions(row, ner_information):
    """
    Check if the question of a data row differs from the question of the NER information.

    :param row: data row; row['question'] is read.
    :param ner_information: NER information; ner_information['question'] is read.
    :return: True if the questions are considered different, False otherwise.
    """
    data_question = row['question']
    ner_question = ner_information['question']

    if len(data_question) == len(ner_question):
        return data_question != ner_question

    # different length: ignore the last two characters of the data question and the last one of the NER question.
    return data_question[:-2] != ner_question[:-1]


# TODO: this part of the code is not used currently, as the pre-processing of ner-values is done as part of the standard pre-processing. See pre_process.py.
# if __name__ == '__main__':
#     arg_parser = argparse.ArgumentParser()
#     arg_parser.add_argument('--data_path', type=str, required=True)
#     arg_parser.add_argument('--ner_data_path', type=str, required=True)
#     arg_parser.add_argument('--output_path', type=str, required=True)
#     arg_parser.add_argument('--database_folder', type=str, default='data/spider/original/database')
#     arg_parser.add_argument('--database_schema', type=str, default='data/spider/original/tables.json')
#
#     args = arg_parser.parse_args()
#
#     with open(os.path.join(args.data_path), 'r', encoding='utf-8') as json_file:
#         data = json.load(json_file)
#
#     with open(os.path.join(args.ner_data_path), 'r', encoding='utf-8') as json_file:
#         ner_data = json.load(json_file)
#
#     assert len(data) == len(ner_data), 'Both, NER data and actual data (e.g. ner_train.json and preprocessed_train.json) need to have the same amount of rows!'
#
#     # add both, the ner-extracted values and the actual values (extracted from the SQL-ground truth) to the data file.
#     for idx, (row, ner_information) in enumerate(zip(data, ner_data)):
#         if compare_questions(row, ner_information):
#             print(f'{idx}       {row["question"]}               {ner_information["question"]}')
#         row['ner_extracted_values'] = ner_information['entities']
#         row['values'] = ner_information['values']
#
#     entry_with_values = 0
#     not_found_count = 0
#     total_expected_value_count = 0
#
#     # here we pre-process the NER results and add further values by handcrafted heuristics
#     extracted_values = [pre_process_ner_candidates(row) for idx, row in enumerate(data)]
#     print("Preprocessed all NER values and applied handcrafted handcrafted heuristics. "
#           "Next Step: matching values in database. This might take a while.")
#     print()
#
#     # here we takes the pre-processed values and try to match them in the database. As this process is very time-consuming,
#     # we parallelize it. Important: Parallel() maintains the order of the input data!
#     n_cores = multiprocessing.cpu_count()
#     values_matched_with_database = Parallel(n_jobs=n_cores)(
#         delayed(match_values_in_database)(row['db_id'], extracted_value, args.database_folder, args.database_schema) for extracted_value, row
#         in zip(extracted_values, data))
#     print("Scanned all databases for matching values.")
#     print()
#
#     for row, value_candidates in zip(data, values_matched_with_database):
#         # this method is basically cheating: if we don't find a value in the values candidates, we add it from the ground truth.
#         # This makes sense for training, as we don't want to reduce the training samples because of non-found values. We also mark
#         # the samples where not all values could get extracted, so we can manually fail them during evaluation.
#         value_candidates_adjusted, all_values_found, n_expected_values = add_non_found_values(row['values'],
#                                                                                               value_candidates,
#                                                                                               row['question'],
#                                                                                               row['query'],
#                                                                                               row['db_id'])
#
#         row['ner_extracted_values_processed'] = value_candidates_adjusted
#         row['all_values_found'] = all_values_found
#
#         if not all_values_found:
#             not_found_count += 1
#
#         if n_expected_values > 0:
#             entry_with_values += 1
#
#         total_expected_value_count += n_expected_values
#
#     print()
#     print()
#     print(
#         f"Could find all values in {len(data) - not_found_count} of {len(data)} examples. {entry_with_values} entries "
#         f"contain values, and in {not_found_count} we could't find them. There is a total of {total_expected_value_count} values in this dataset")
#
#     with open(os.path.join(args.output_path), 'w', encoding='utf-8') as f:
#         json.dump(data, f, indent=2)