"""Postprocessing for decoded lip-reading commands. Extracted from calculate_cer.py: token snapping with Levenshtein distance to closed vocabularies (verb / direction / day / number) + optional letter slots. """ from __future__ import annotations import numpy as np # ---- closed vocabularies (keep in sync with training/eval) ---- VERBS = ["potvrdi", "odustani", "obriši", "pošaqi", "daqe", "početak", "kraj"] LETTERS = [ "a", "b", "c", "č", "ć", "d", "x", "đ", "e", "f", "g", "h", "i", "j", "k", "l", "q", "m", "n", "w", "o", "p", "r", "s", "š", "t", "u", "v", "z", "ž" ] DIRECTIONS = ["napred", "nazad", "levo", "desno", "gore", "dole"] DAYS = ["ponedeqak", "utorak", "sreda", "četvrtak", "petak", "subota", "nedeqa"] NUMBERS = ["nula", "jedan", "dva", "tri", "četiri", "pet", "šest", "sedam", "osam", "devet"] def levenshtein(a: str, b: str) -> int: """Classic Levenshtein distance (edit distance).""" a, b = a.strip(), b.strip() if a == b: return 0 if len(a) == 0: return len(b) if len(b) == 0: return len(a) dp = np.zeros((len(a) + 1, len(b) + 1), dtype=np.int32) dp[:, 0] = np.arange(len(a) + 1) dp[0, :] = np.arange(len(b) + 1) for i in range(1, len(a) + 1): for j in range(1, len(b) + 1): cost = 0 if a[i - 1] == b[j - 1] else 1 dp[i, j] = min( dp[i - 1, j] + 1, # deletion dp[i, j - 1] + 1, # insertion dp[i - 1, j - 1] + cost # substitution ) return int(dp[-1, -1]) def closest_in_vocab(token: str, vocab: list[str]) -> str: token = token.strip() if not token: return token best = vocab[0] best_d = 1e9 for v in vocab: d = levenshtein(token, v) if d < best_d: best_d = d best = v return best def pick_best_for_vocab(tokens: list[str], used: set[int], vocab: list[str]): """Pick the token (by index) whose best match in vocab has smallest distance.""" best_idx = None best_word = vocab[0] best_d = 1e9 any_candidate = False for i, tok in enumerate(tokens): if i in used: continue any_candidate = True tok = tok.strip() if not tok: continue local_best = vocab[0] local_best_d = 1e9 for v in vocab: d = levenshtein(tok, v) if d < local_best_d: local_best_d = d local_best = v if local_best_d < best_d: best_d = local_best_d best_word = local_best best_idx = i if not any_candidate: return None, vocab[0] return best_idx, best_word def postprocess_command(pred: str) -> str: """Snap a raw decoded string into a structured command.""" tokens = pred.strip().split() if not tokens: return pred letters = [] non_letters = [] for tok in tokens: if len(tok) == 1: letters.append(tok) else: non_letters.append(tok) if not non_letters: return pred used: set[int] = set() verb_idx, verb = pick_best_for_vocab(non_letters, used, VERBS) if verb_idx is not None: used.add(verb_idx) dir_idx, direction = pick_best_for_vocab(non_letters, used, DIRECTIONS) if dir_idx is not None: used.add(dir_idx) day_idx, day = pick_best_for_vocab(non_letters, used, DAYS) if day_idx is not None: used.add(day_idx) num_idx, number = pick_best_for_vocab(non_letters, used, NUMBERS) if num_idx is not None: used.add(num_idx) if len(letters) == 0: out_tokens = [verb, direction, day, number] elif len(letters) == 1: L1 = closest_in_vocab(letters[0], LETTERS) out_tokens = [verb, L1, direction, day, number] else: L1 = closest_in_vocab(letters[0], LETTERS) L2 = closest_in_vocab(letters[1], LETTERS) out_tokens = [verb, L1, direction, L2, day, number] return " ".join(out_tokens)