# AI-SPEAK / LipReadingApp / postprocess.py
"""Postprocessing for decoded lip-reading commands.

Extracted from calculate_cer.py: token snapping with Levenshtein distance to
closed vocabularies (verb / direction / day / number) + optional letter slots.
"""

from __future__ import annotations

import numpy as np

# ---- closed vocabularies (keep in sync with training/eval) ----
# NOTE(review): spellings such as "pošaqi", "daqe", "ponedeqak", "nedeqa" and the
# single-character "q" / "w" / "x" entries in LETTERS look like a one-character
# transliteration of the digraphs lj / nj / dž — confirm against the label
# encoding used in training before editing any of these lists.
#
# Order matters in the word vocabularies below: pick_best_for_vocab returns
# vocab[0] when no token is left for a slot, so the first entry of each list is
# the default that postprocess_command emits for a missing word.

# Verb slot (first word of every command).
VERBS = ["potvrdi", "odustani", "obriši", "pošaqi", "daqe", "početak", "kraj"]

# Letter slots. Every entry must stay a single character: postprocess_command
# routes only length-1 tokens to the letter slots.
LETTERS = [
    "a", "b", "c", "č", "ć", "d", "x", "đ", "e", "f",
    "g", "h", "i", "j", "k", "l", "q", "m", "n", "w",
    "o", "p", "r", "s", "š", "t", "u", "v", "z", "ž"
]

# Direction, day and number slots, in the order postprocess_command fills them.
DIRECTIONS = ["napred", "nazad", "levo", "desno", "gore", "dole"]
DAYS = ["ponedeqak", "utorak", "sreda", "četvrtak", "petak", "subota", "nedeqa"]
NUMBERS = ["nula", "jedan", "dva", "tri", "četiri", "pet", "šest", "sedam", "osam", "devet"]


def levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein (edit) distance between ``a`` and ``b``.

    The result is the minimum number of single-character insertions,
    deletions and substitutions needed to turn one string into the other.
    Leading/trailing whitespace is stripped from both inputs before comparing.

    Uses the standard dynamic-programming recurrence but keeps only two rows
    of plain Python ints. This avoids allocating a full matrix and the
    per-element overhead of indexing a numpy array inside Python loops.
    """
    a, b = a.strip(), b.strip()
    if a == b:
        return 0
    # The distance is symmetric, so iterate over the longer string and keep
    # rows as long as the shorter one.
    if len(a) < len(b):
        a, b = b, a
    if not b:
        return len(a)

    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # deletion
                curr[j - 1] + 1,           # insertion
                prev[j - 1] + (ca != cb),  # substitution (free when chars match)
            ))
        prev = curr
    return prev[-1]


def closest_in_vocab(token: str, vocab: list[str]) -> str:
    """Snap ``token`` to the nearest word of ``vocab`` by edit distance.

    The token is stripped first. If nothing remains, ``""`` is returned
    without looking at the vocabulary. Otherwise the word with the smallest
    :func:`levenshtein` distance wins, and on ties the earliest word in
    ``vocab`` is kept. Indexing an empty ``vocab`` raises ``IndexError``.
    """
    cleaned = token.strip()
    if cleaned == "":
        return cleaned

    # Seed with the first word, then only replace it on a strictly better match.
    winner = vocab[0]
    winner_dist = levenshtein(cleaned, winner)
    for candidate in vocab[1:]:
        dist = levenshtein(cleaned, candidate)
        if dist < winner_dist:
            winner, winner_dist = candidate, dist
    return winner


def pick_best_for_vocab(tokens: list[str], used: set[int], vocab: list[str]):
    """Pick the token (by index) whose best match in vocab has smallest distance."""
    best_idx = None
    best_word = vocab[0]
    best_d = 1e9
    any_candidate = False

    for i, tok in enumerate(tokens):
        if i in used:
            continue
        any_candidate = True
        tok = tok.strip()
        if not tok:
            continue

        local_best = vocab[0]
        local_best_d = 1e9
        for v in vocab:
            d = levenshtein(tok, v)
            if d < local_best_d:
                local_best_d = d
                local_best = v

        if local_best_d < best_d:
            best_d = local_best_d
            best_word = local_best
            best_idx = i

    if not any_candidate:
        return None, vocab[0]
    return best_idx, best_word


def postprocess_command(pred: str) -> str:
    """Snap a raw decoded string into a structured command.

    ``pred`` is split on whitespace. Single-character tokens are candidates
    for the letter slots, and all longer tokens are candidates for the word
    slots. The word slots are filled in this order: verb, direction, day,
    number. Each one greedily claims the still-unused word token closest to
    its vocabulary via :func:`pick_best_for_vocab`. The first one or two
    letter tokens are snapped to ``LETTERS``, and any further letter tokens
    are dropped.

    The output layout depends on how many letter tokens were seen:

    * none: ``verb direction day number``
    * one:  ``verb L1 direction day number``
    * two+: ``verb L1 direction L2 day number``

    A word slot that finds no remaining token is filled with the first word
    of its vocabulary, so all four word slots are always present.

    Args:
        pred: Raw decoded prediction.

    Returns:
        The snapped command string. ``pred`` is returned unchanged when it
        has no tokens or when every token is a single character.
    """
    tokens = pred.strip().split()
    if not tokens:
        return pred

    letters = [tok for tok in tokens if len(tok) == 1]
    words = [tok for tok in tokens if len(tok) != 1]
    if not words:
        # Nothing to anchor the word slots on; leave the prediction untouched.
        return pred

    used: set[int] = set()
    snapped_words: list[str] = []
    # Order matters: earlier slots get first pick of the word tokens.
    for vocab in (VERBS, DIRECTIONS, DAYS, NUMBERS):
        idx, word = pick_best_for_vocab(words, used, vocab)
        if idx is not None:
            used.add(idx)
        snapped_words.append(word)
    verb, direction, day, number = snapped_words

    snapped_letters = [closest_in_vocab(tok, LETTERS) for tok in letters[:2]]
    if not snapped_letters:
        out_tokens = [verb, direction, day, number]
    elif len(snapped_letters) == 1:
        out_tokens = [verb, snapped_letters[0], direction, day, number]
    else:
        first_letter, second_letter = snapped_letters
        out_tokens = [verb, first_letter, direction, second_letter, day, number]

    return " ".join(out_tokens)