"""Postprocessing for decoded lip-reading commands.
Extracted from calculate_cer.py: token snapping with Levenshtein distance to
closed vocabularies (verb / direction / day / number) + optional letter slots.
"""
from __future__ import annotations
import numpy as np
# ---- closed vocabularies (keep in sync with training/eval) ----
# Order matters: the first entry of each list is what the snapping helpers
# fall back to when no token is available for that slot.
# NOTE(review): "q", "x" and "w" look like single-character stand-ins for
# digraphs (e.g. "pošaqi", "ponedeqak") — presumably matching the training
# transcripts; confirm before "correcting" these spellings.
# Command verbs.
VERBS = ["potvrdi", "odustani", "obriši", "pošaqi", "daqe", "početak", "kraj"]
# Alphabet that single-character tokens are snapped to.
LETTERS = [
"a", "b", "c", "č", "ć", "d", "x", "đ", "e", "f",
"g", "h", "i", "j", "k", "l", "q", "m", "n", "w",
"o", "p", "r", "s", "š", "t", "u", "v", "z", "ž"
]
# Direction words.
DIRECTIONS = ["napred", "nazad", "levo", "desno", "gore", "dole"]
# Day names.
DAYS = ["ponedeqak", "utorak", "sreda", "četvrtak", "petak", "subota", "nedeqa"]
# Digit words, zero through nine.
NUMBERS = ["nula", "jedan", "dva", "tri", "četiri", "pet", "šest", "sedam", "osam", "devet"]
def levenshtein(a: str, b: str) -> int:
    """Return the edit distance between ``a`` and ``b``.

    Both strings are stripped of surrounding whitespace first. Insertions,
    deletions and substitutions each cost 1.
    """
    src, dst = a.strip(), b.strip()
    if src == dst:
        return 0
    if not src or not dst:
        # Exactly one side is empty here, so the distance is the other's length.
        return max(len(src), len(dst))

    # Rolling-row DP: prev_row[j] is the distance between the part of src
    # consumed so far and dst[:j].
    prev_row = list(range(len(dst) + 1))
    for i, src_ch in enumerate(src, start=1):
        row = [i]
        for j, dst_ch in enumerate(dst, start=1):
            substitution = prev_row[j - 1] + (0 if src_ch == dst_ch else 1)
            deletion = prev_row[j] + 1
            insertion = row[j - 1] + 1
            row.append(min(deletion, insertion, substitution))
        prev_row = row
    return prev_row[-1]
def closest_in_vocab(token: str, vocab: list[str]) -> str:
    """Return the entry of ``vocab`` nearest to ``token`` by edit distance.

    The token is stripped first; a blank token is returned (stripped) without
    consulting the vocabulary. When several entries are equally close, the one
    that appears first in ``vocab`` is returned.
    """
    cleaned = token.strip()
    if cleaned == "":
        return cleaned

    # Seed with the first entry so that later entries only replace it when
    # they are strictly closer (first-wins tie-breaking).
    nearest = vocab[0]
    nearest_dist = levenshtein(cleaned, nearest)
    for word in vocab[1:]:
        dist = levenshtein(cleaned, word)
        if dist < nearest_dist:
            nearest, nearest_dist = word, dist
    return nearest
def pick_best_for_vocab(tokens: list[str], used: set[int], vocab: list[str]):
    """Choose the unused token that lies closest to any word in ``vocab``.

    Returns a ``(index, word)`` pair: the position in ``tokens`` of the chosen
    token and the vocabulary word it is nearest to. Indices listed in ``used``
    and blank tokens are skipped. If nothing is left to choose from, the
    result is ``(None, vocab[0])``. Ties go to the earliest token and, within
    a token, to the earliest vocabulary word.
    """
    winner_idx = None
    winner_word = vocab[0]
    winner_dist = float("inf")

    for idx, raw in enumerate(tokens):
        if idx in used:
            continue
        candidate = raw.strip()
        if not candidate:
            continue
        # min() keeps the first of several equally distant words.
        dist, word = min(
            ((levenshtein(candidate, entry), entry) for entry in vocab),
            key=lambda scored: scored[0],
        )
        if dist < winner_dist:
            winner_idx, winner_word, winner_dist = idx, word, dist

    return winner_idx, winner_word
def postprocess_command(pred: str) -> str:
    """Snap a raw decoded string onto the fixed command layout.

    The prediction is split on whitespace. Single-character tokens are treated
    as letters, everything else as words. Words are assigned, in this order,
    to the verb, direction, day and number slots; each slot takes the unused
    word that best matches its vocabulary, and a slot with no word left
    receives that vocabulary's first entry. At most the first two letters are
    kept and snapped to ``LETTERS``.

    Output layout (space-separated):
        no letters:   verb direction day number
        one letter:   verb letter direction day number
        two or more:  verb letter direction letter day number

    ``pred`` is returned unchanged when it has no tokens or only
    single-character tokens.
    """
    words = pred.strip().split()
    if not words:
        return pred

    single_chars = [w for w in words if len(w) == 1]
    multi_chars = [w for w in words if len(w) != 1]
    if not multi_chars:
        return pred

    # Fill the slots in a fixed order; each chosen word is consumed so that
    # later slots cannot reuse it.
    taken: set[int] = set()
    slots = []
    for vocab in (VERBS, DIRECTIONS, DAYS, NUMBERS):
        idx, word = pick_best_for_vocab(multi_chars, taken, vocab)
        if idx is not None:
            taken.add(idx)
        slots.append(word)
    verb, direction, day, number = slots

    # Only the first two letters have a place in the layout.
    snapped = [closest_in_vocab(ch, LETTERS) for ch in single_chars[:2]]
    if not snapped:
        pieces = [verb, direction, day, number]
    elif len(snapped) == 1:
        pieces = [verb, snapped[0], direction, day, number]
    else:
        pieces = [verb, snapped[0], direction, snapped[1], day, number]
    return " ".join(pieces)