AI-SPEAK / LipReadingApp / inference.py
inference.py
Raw
# inference.py
import os
from typing import List, Optional

import numpy as np
import tensorflow as tf

from config import MODEL_PATH

import json
from config import NORM_STATS_1080x1920_JSON, NORM_STATS_OTHER_JSON
from postprocess import postprocess_command

_STATS_CACHE = {}  # path -> (mean, std); only successful loads are cached


def _load_stats_json(path):
    """
    Load scalar normalization stats ``{"mean": ..., "std": ...}`` from a JSON file.

    Returns:
        A ``(mean, std)`` tuple of floats (cached per path), or ``None`` when
        ``path`` is empty, the file cannot be read or parsed, a key is missing or
        not numeric, or either value is not finite. A ``std`` whose magnitude is
        below 1e-6 is replaced by 1.0 so callers can divide by it safely.
    """
    if not path:
        return None
    cached = _STATS_CACHE.get(path)
    if cached is not None:
        return cached
    try:
        with open(path, "r", encoding="utf-8") as f:
            d = json.load(f)
        mean = float(d["mean"])
        std = float(d["std"])
    except (OSError, ValueError, KeyError, TypeError):
        # Best-effort: on None the caller falls back to clip-level z-scoring.
        return None
    # NaN/inf stats would silently turn every normalized frame into NaN/inf.
    if not (np.isfinite(mean) and np.isfinite(std)):
        return None
    if abs(std) < 1e-6:
        std = 1.0
    _STATS_CACHE[path] = (mean, std)
    return _STATS_CACHE[path]


def _normalize_with_profile(frames_gray_list, profile_key: str):
    """
    Normalize a clip using the dataset statistics of its resolution profile.

    frames_gray_list: list of (H,W) frames AFTER CLAHE + resize (100x50).
    profile_key: "1080x1920" selects the 1080x1920 stats file; any other value
        (normally "other") selects the "other" stats file.

    When no usable stats can be loaded for the profile, the clip is z-scored with
    its own mean/std instead (a std below 1e-6 is treated as 1.0).
    Returns a list of per-frame float32 arrays.
    """
    clip = np.stack(frames_gray_list, axis=0).astype(np.float32)  # (T,H,W)

    stats_file = NORM_STATS_1080x1920_JSON if profile_key == "1080x1920" else NORM_STATS_OTHER_JSON
    stats = _load_stats_json(stats_file)

    if stats is None:
        # Fallback: clip-level z-score.
        center = float(clip.mean())
        spread = float(clip.std())
        stats = (center, 1.0 if spread < 1e-6 else spread)

    center, spread = stats
    return list((clip - center) / spread)


# -----------------------------
# Load model (inference-only)
# -----------------------------
def _friendly_load_model(path: str):
    """
    Load the inference model, failing early with clear messages for bad paths.

    ``path`` may be a SavedModel directory, which must contain ``saved_model.pb``,
    or a single model file (e.g. ``.keras`` / ``.h5``). A file is handed straight
    to ``tf.keras.models.load_model``. The model is loaded with ``compile=False``
    because it is only used for inference here.

    Raises:
        FileNotFoundError: if ``path`` does not exist, or is a directory without
            ``saved_model.pb``.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model path does not exist: {path}")
    # Only a SavedModel *directory* needs saved_model.pb; single-file formats don't.
    if os.path.isdir(path) and not os.path.exists(os.path.join(path, "saved_model.pb")):
        raise FileNotFoundError("Missing saved_model.pb in model folder.")
    return tf.keras.models.load_model(path, compile=False)

# Load once at import time. On any failure, report it and leave model=None so
# InferenceRunner.run() returns a message instead of the import crashing.
model = None
try:
    model = _friendly_load_model(MODEL_PATH)
except Exception as load_err:
    print(load_err)


def vocabulary():
    """
    Build the character <-> integer-id lookup layers used to decode model output.

    The order of the character list is significant. StringLookup assigns ids by
    list position, after the OOV slot(s) it reserves in front. Reordering or
    editing this list therefore changes which character each model output id
    decodes to. It is deliberately not alphabetical; do not sort it.

    Returns:
        (char_to_num, num_to_char): a forward ``StringLookup`` and its inverse. The
        inverse is built from the forward layer's full vocabulary (OOV token
        included), so the two layers agree on every id.
    """
    vocab = ["a", "b", "c", "č", "ć", "d", "x", "đ", "e", "f", "g", "h", "i", "j", "k", "l", "q",
             "m", "n", "w", "o", "p", "r", "s", "š", "t", "u", "v", "z", "ž", " "]
    char_to_num = tf.keras.layers.StringLookup(vocabulary=vocab, oov_token="")
    num_to_char = tf.keras.layers.StringLookup(
        vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
    )
    return char_to_num, num_to_char


# Built once at import time; InferenceRunner.run() decodes model output with num_to_char.
char_to_num, num_to_char = vocabulary()

# -----------------------------
# Time-length handling
# -----------------------------
def _expected_T_from_model(m) -> Optional[int]:
    """
    Return the static time-axis length (index 1) of the model's first input, if known.

    ``m.input_shape`` may be a single shape tuple, or a list/tuple of shapes for a
    multi-input model, in which case the first shape is used.

    Returns:
        The int length, or None when that axis is dynamic (None) or the shape
        cannot be read.
    """
    try:
        ish = m.input_shape
        # Unwrap only a *collection of shapes*. A plain shape such as (None, T, ...)
        # is itself the shape; unwrapping it would yield its first entry instead.
        if isinstance(ish, (list, tuple)) and len(ish) > 0 and isinstance(ish[0], (list, tuple)):
            ish = ish[0]
        t = ish[1]
        return int(t) if t is not None else None
    except Exception:
        # Deliberately broad: any failure to introspect just means "unknown".
        return None

# Fixed clip length (in frames) fed to the model, hard-coded to match the training
# setup. run() center-trims longer clips and zero-pads shorter ones at the end to
# this length. NOTE(review): _expected_T_from_model() exists but is not consulted here.
EXPECTED_T = 150

# -----------------------------
# Normalization
# -----------------------------
_NORM_CACHE = {"mean": None, "std": None, "loaded": False}


def _load_norm_array(path):
    """
    Load one array from a ``.npy`` or ``.npz`` file for external normalization.

    For an ``.npz`` archive the first stored array is returned and the archive is
    closed afterwards.

    Returns:
        The array, or None when ``path`` is empty, the file does not exist, the
        archive is empty, or NumPy cannot read the file.
    """
    if not path or not os.path.exists(path):
        return None
    try:
        obj = np.load(path)
    except (OSError, ValueError):
        # Unreadable/corrupt (or pickled) file: treat as "no external stats".
        return None
    if isinstance(obj, np.lib.npyio.NpzFile):
        # NpzFile keeps the underlying file open until closed.
        with obj:
            names = obj.files
            return np.asarray(obj[names[0]]) if names else None
    return obj


def _maybe_load_external_norm():
    """
    Lazy-load optional external normalization stats into ``_NORM_CACHE`` (once).

    The file paths come from ``config.NORM_MEAN_PATH`` / ``config.NORM_STD_PATH``.
    If config does not define them, or a file cannot be loaded, the corresponding
    stat stays None and callers fall back to clip-level normalization.
    """
    if _NORM_CACHE["loaded"]:
        return
    _NORM_CACHE["loaded"] = True

    # Optional config entries: import lazily so their absence disables external
    # stats instead of raising NameError here.
    try:
        from config import NORM_MEAN_PATH, NORM_STD_PATH
    except ImportError:
        return

    _NORM_CACHE["mean"] = _load_norm_array(NORM_MEAN_PATH)
    _NORM_CACHE["std"] = _load_norm_array(NORM_STD_PATH)


def _normalize_clip(frames_gray: List[np.ndarray], already_normalized: bool) -> List[np.ndarray]:
    """
    Normalize a clip and return it as a list of float32 frames.

    Order of precedence:
      1) already_normalized=True -> values untouched, only cast to float32
      2) external mean/std loaded by _maybe_load_external_norm() -> (x - mean) / std,
         with entries where |std| < 1e-6 replaced by 1
      3) otherwise -> z-score over the whole clip (all T*H*W values)
    """
    clip = np.stack(frames_gray, axis=0).astype(np.float32)  # (T,H,W)

    if not already_normalized:
        _maybe_load_external_norm()
        ext_mean, ext_std = _NORM_CACHE["mean"], _NORM_CACHE["std"]
        if ext_mean is None or ext_std is None:
            # No external stats: use this clip's own statistics.
            center = float(clip.mean())
            spread = float(clip.std())
            clip = (clip - center) / (1.0 if spread < 1e-6 else spread)
        else:
            ext_mean = np.asarray(ext_mean, dtype=np.float32)
            ext_std = np.asarray(ext_std, dtype=np.float32)
            # Avoid dividing by (near-)zero std, element-wise.
            ext_std = np.where(np.abs(ext_std) < 1e-6, 1.0, ext_std)
            clip = (clip - ext_mean) / ext_std

    # Iterating the stacked array yields its frames along axis 0.
    return list(clip)


def _pad_or_trim_after_norm(frames: List[np.ndarray], target_T: Optional[int], mode: str = "center") -> List[np.ndarray]:
    """
    Force a list of (already normalized) frames to exactly ``target_T`` frames.

    - ``target_T`` is None or equals the current length: ``frames`` is returned as is.
    - Too long: ``mode="center"`` drops frames from both ends (an odd extra frame
      comes off the end); any other mode keeps the first ``target_T`` frames.
    - Too short: float32 zero frames shaped like ``frames[0]`` are appended at the end.

    Raises:
        ValueError: if ``target_T`` is negative, or padding is needed but ``frames``
            is empty.
    """
    T = len(frames)
    if target_T is None or T == target_T:
        return frames
    if target_T < 0:
        raise ValueError(f"target_T must be non-negative, got {target_T}.")
    if T > target_T:
        if mode == "center":
            left = (T - target_T) // 2
            return frames[left:left + target_T]
        return frames[:target_T]

    if T == 0:
        raise ValueError("No frames captured.")
    # zeros_like works for any frame rank, e.g. (H,W) or (H,W,1).
    pad = [np.zeros_like(frames[0], dtype=np.float32) for _ in range(target_T - T)]
    return frames + pad


def _frames_to_tensor_no_norm(frames_gray: List[np.ndarray]) -> Optional[tf.Tensor]:
    """
    Pack frames into a batch-of-one float32 model input, without normalizing them.

    The frames are stacked on a new leading time axis. A batch axis is added in
    front and a channel axis at the end, giving (1, T, H, W, 1) for (H, W) frames.
    Returns None for an empty sequence.
    """
    if len(frames_gray) == 0:
        return None
    clip = np.stack(frames_gray, axis=0).astype(np.float32)
    batch = clip[np.newaxis, ..., np.newaxis]
    return tf.convert_to_tensor(batch, dtype=tf.float32)


class InferenceRunner:
    """
    Run the lip-reading model on one clip of grayscale frames and decode it to text.

    ``run`` accepts these payload formats:
      1) ``frames_list``
      2) ``(frames_list, profile_key: str)``          e.g. "1080x1920" / "other"
      3) ``(frames_list, already_normalized: bool)``  backward-compatible; when truthy,
         the frames are only cast to float32 instead of being normalized again.
    """

    def __init__(self):
        # Not read by this class; kept so existing code touching it keeps working.
        self._busy = False

    @staticmethod
    def _parse_payload(payload):
        """
        Split ``payload`` into ``(frames, profile_key, already_normalized)``.

        A 2-tuple is read as ``(frames, second)``. A ``str`` second element selects
        the normalization profile; anything else is coerced with ``bool`` into the
        already-normalized flag (with profile "other"). Every other payload is
        treated as the frames themselves.
        """
        if isinstance(payload, tuple) and len(payload) == 2:
            frames, second = payload
            if isinstance(second, str):
                return frames, second, False
            return frames, "other", bool(second)
        return payload, "other", False

    @staticmethod
    def _decode_length(orig_T, expected_T, out_T):
        """
        Number of model output steps to CTC-decode for a clip of ``orig_T`` frames.

        The result excludes zero frames appended when the clip was padded up to
        ``expected_T`` (None means no length adjustment). It never exceeds
        ``out_T``, the time dimension of the model output.
        NOTE(review): assumes one output step per input frame; confirm for the model.
        """
        valid = orig_T if expected_T is None else min(orig_T, expected_T)
        return min(valid, out_T)

    def run(self, payload) -> str:
        """
        Normalize, length-adjust, predict and CTC-decode one payload.

        Returns one post-processed line per decoded batch row, joined with "\\n".
        Missing model, missing payload, or no frames are reported as a returned
        status string rather than raised.
        """
        if model is None or num_to_char is None:
            return "Model or num_to_char not set. Please plug them in."
        if payload is None:
            return "[No input]"

        frames_gray_list, profile_key, already_norm = self._parse_payload(payload)

        # len() instead of truthiness so a stacked ndarray of frames is accepted too.
        if frames_gray_list is None or len(frames_gray_list) == 0:
            return "[No frames captured]"

        if already_norm:
            # Caller already normalized: only cast to float32, don't normalize twice.
            frames_norm = list(np.stack(frames_gray_list, axis=0).astype(np.float32))
        else:
            frames_norm = _normalize_with_profile(frames_gray_list, profile_key)

        orig_T = len(frames_norm)
        if EXPECTED_T is not None and orig_T != EXPECTED_T:
            frames_norm = _pad_or_trim_after_norm(frames_norm, EXPECTED_T, mode="center")
            print(f"[Inference] Adjusted T from {orig_T} -> {EXPECTED_T}")

        X = _frames_to_tensor_no_norm(frames_norm)
        if X is None:
            return "[Tensor creation failed]"

        yhat = model.predict(X)
        batch_size, out_T = yhat.shape[0], yhat.shape[1]
        true_len = self._decode_length(orig_T, EXPECTED_T, out_T)
        decoded = tf.keras.backend.ctc_decode(
            yhat,
            input_length=[true_len] * batch_size,  # one length per batch row
            greedy=False,
        )[0][0].numpy()

        lines = []
        for ids in decoded:
            if isinstance(num_to_char, tf.keras.layers.StringLookup):
                ids = ids[ids != -1]  # -1 marks padding in the decoded id matrix
                s = tf.strings.reduce_join(num_to_char(ids)).numpy().decode("utf-8")
            else:
                s = "".join(num_to_char(int(x)) for x in ids if int(x) >= 0)
            lines.append(postprocess_command(s))
        return "\n".join(lines)