# inference.py
import os
from typing import List, Optional
import numpy as np
import tensorflow as tf
from config import MODEL_PATH
import json
from config import NORM_STATS_1080x1920_JSON, NORM_STATS_OTHER_JSON
from postprocess import postprocess_command
_STATS_CACHE = {}  # path -> (mean, std), or None when that file could not be used


def _load_stats_json(path):
    """Load scalar normalization stats from a JSON file, memoized per path.

    The file must be a JSON object with numeric "mean" and "std" entries.
    A near-zero std (|std| < 1e-6) is replaced by 1.0 so dividing by it is safe.

    Args:
        path: path to the JSON stats file; a falsy value means "no stats".

    Returns:
        ``(mean, std)`` as Python floats, or None when *path* is empty or the
        file is missing, unreadable, malformed, or holds non-finite values.
        Failures are reported once and cached, so callers fall back to their
        own normalization without re-reading a bad file on every call.
    """
    if not path:
        return None
    if path in _STATS_CACHE:
        return _STATS_CACHE[path]
    try:
        with open(path, "r", encoding="utf-8") as f:
            d = json.load(f)
        mean = float(d["mean"])
        std = float(d["std"])
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # OSError: missing/unreadable file; ValueError: bad JSON or non-numeric
        # value; KeyError: missing key; TypeError: top-level JSON is not an object.
        print(f"[Inference] Could not load normalization stats from {path!r}: {exc}")
        _STATS_CACHE[path] = None
        return None
    # json.load accepts NaN/Infinity; such stats would turn every frame into NaN.
    if not (np.isfinite(mean) and np.isfinite(std)):
        print(f"[Inference] Ignoring non-finite normalization stats in {path!r}")
        _STATS_CACHE[path] = None
        return None
    if abs(std) < 1e-6:
        std = 1.0
    _STATS_CACHE[path] = (mean, std)
    return _STATS_CACHE[path]
def _normalize_with_profile(frames_gray_list, profile_key: str):
    """Z-normalize a clip using the stats file that matches its capture profile.

    Args:
        frames_gray_list: sequence of (H, W) frames, taken after CLAHE + resize
            (100x50).
        profile_key: "1080x1920" selects NORM_STATS_1080x1920_JSON; any other
            value (normally "other") selects NORM_STATS_OTHER_JSON.

    Returns:
        A list with one normalized frame per input frame. When the profile's
        stats cannot be loaded, the clip is z-scored with its own mean/std over
        all T*H*W values; a near-zero std is replaced by 1.0.
    """
    clip = np.stack(frames_gray_list, axis=0).astype(np.float32)  # (T, H, W)
    stats_path = NORM_STATS_1080x1920_JSON if profile_key == "1080x1920" else NORM_STATS_OTHER_JSON
    stats = _load_stats_json(stats_path)
    if stats is None:
        # No usable stats file: fall back to the clip's own statistics.
        center = float(clip.mean())
        scale = float(clip.std())
        if scale < 1e-6:
            scale = 1.0
    else:
        center, scale = stats
    return list((clip - center) / scale)
# -----------------------------
# Load model (inference-only)
# -----------------------------
def _friendly_load_model(path: str):
    """Load a Keras SavedModel directory for inference only.

    Validates the directory up front so a bad path yields a readable
    FileNotFoundError instead of a deep Keras error. The model is loaded with
    ``compile=False``, so no optimizer or loss is restored.

    Raises:
        FileNotFoundError: if *path* does not exist or has no saved_model.pb.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model path does not exist: {path}")
    graph_file = os.path.join(path, "saved_model.pb")
    if not os.path.exists(graph_file):
        raise FileNotFoundError("Missing saved_model.pb in model folder.")
    return tf.keras.models.load_model(path, compile=False)
# Load once at import time. A failure is printed rather than raised so the
# module still imports; InferenceRunner.run then reports that no model is set.
model = None
try:
    model = _friendly_load_model(MODEL_PATH)
except Exception as load_err:
    print(load_err)
def vocabulary():
    """Build the character <-> id lookup layers used to decode model output.

    ``StringLookup`` assigns ids in list order, so the (non-alphabetical) order
    below presumably has to match the one used at training time — do not
    reorder without confirming. Both layers use ``oov_token=""``.

    Returns:
        ``(char_to_num, num_to_char)``: the forward lookup and its inverse,
        the latter built from the forward layer's own vocabulary.
    """
    chars = [
        "a", "b", "c", "č", "ć", "d", "x", "đ", "e", "f", "g", "h", "i", "j",
        "k", "l", "q", "m", "n", "w", "o", "p", "r", "s", "š", "t", "u", "v",
        "z", "ž", " ",
    ]
    encoder = tf.keras.layers.StringLookup(vocabulary=chars, oov_token="")
    decoder = tf.keras.layers.StringLookup(
        vocabulary=encoder.get_vocabulary(),
        oov_token="",
        invert=True,
    )
    return encoder, decoder


char_to_num, num_to_char = vocabulary()
# -----------------------------
# Time-length handling
# -----------------------------
def _expected_T_from_model(m) -> Optional[int]:
    """Read the fixed time length (axis 1) from a Keras model's input shape.

    Handles single-input models, whose ``input_shape`` is one shape tuple such
    as ``(None, T, ...)``, and multi-input models, whose ``input_shape`` is a
    list of shape tuples (the first input is used).

    Returns:
        T as an int, or None when the time axis is dynamic, the shape has
        fewer than two axes, or the shape cannot be read at all.
    """
    try:
        shape = m.input_shape
    except Exception:
        # Deliberate best-effort probe: unbuilt models or non-Keras objects
        # may raise various errors here; the caller just gets "unknown".
        return None
    # Unwrap only a list of shapes. A single-input shape is itself a tuple, and
    # unwrapping it would yield the batch dimension (None) instead of a shape.
    if isinstance(shape, (list, tuple)) and shape and isinstance(shape[0], (list, tuple)):
        shape = shape[0]
    if not isinstance(shape, (list, tuple)) or len(shape) < 2 or shape[1] is None:
        return None
    try:
        return int(shape[1])
    except (TypeError, ValueError):
        return None
# Number of frames fed to the model per clip; InferenceRunner.run center-trims
# or zero-pads every clip to this length. Presumably must equal the clip
# length used during training (the original note says it does) — confirm.
EXPECTED_T: Optional[int] = 150
# -----------------------------
# Normalization
# -----------------------------
_NORM_CACHE = {"mean": None, "std": None, "loaded": False}


def _maybe_load_external_norm():
    """Lazy-load optional external mean/std arrays into ``_NORM_CACHE`` once.

    The file paths come from ``NORM_MEAN_PATH`` / ``NORM_STD_PATH``. Those
    names are not imported at module top, so they are looked up in this
    module's globals first and then in ``config``. If neither defines a name,
    or its file is missing or unreadable, that cache entry stays None and
    ``_normalize_clip`` falls back to per-clip z-scoring.

    Each path may point to a ``.npy`` file or a ``.npz`` archive (the first
    stored array is used). Loading is attempted only once per process, even
    when it fails.
    """
    if _NORM_CACHE["loaded"]:
        return
    _NORM_CACHE["loaded"] = True

    def _resolve(name):
        # Referencing the bare name would raise NameError; resolve it lazily.
        value = globals().get(name)
        if value is not None:
            return value
        try:
            import config
        except ImportError:
            return None
        return getattr(config, name, None)

    def _load(path):
        if not path or not os.path.exists(path):
            return None
        try:
            # allow_pickle stays at its safe default (False).
            obj = np.load(path)
        except (OSError, ValueError, EOFError) as exc:
            print(f"[Inference] Could not load normalization array {path!r}: {exc}")
            return None
        if isinstance(obj, np.lib.npyio.NpzFile):
            # An NpzFile keeps the archive open; close it once the array is read.
            with obj:
                keys = list(obj.keys())
                return obj[keys[0]] if keys else None
        return obj

    _NORM_CACHE["mean"] = _load(_resolve("NORM_MEAN_PATH"))
    _NORM_CACHE["std"] = _load(_resolve("NORM_STD_PATH"))
def _normalize_clip(frames_gray: List[np.ndarray], already_normalized: bool) -> List[np.ndarray]:
    """Normalize a clip of frames and return it as a list of per-frame arrays.

    Order of precedence:
      1) ``already_normalized`` is True -> values untouched, only cast to float32;
      2) external mean/std arrays are available -> ``(x - mean) / std``, with
         near-zero std entries replaced by 1.0;
      3) otherwise -> z-score over the whole clip (all T*H*W values).
    """
    clip = np.stack(frames_gray, axis=0).astype(np.float32)  # (T, H, W)
    if not already_normalized:
        _maybe_load_external_norm()
        ext_mean, ext_std = _NORM_CACHE["mean"], _NORM_CACHE["std"]
        if ext_mean is None or ext_std is None:
            # No external stats: use the clip's own statistics.
            center = float(clip.mean())
            spread = float(clip.std())
            clip = (clip - center) / (1.0 if spread < 1e-6 else spread)
        else:
            ext_mean = np.asarray(ext_mean, dtype=np.float32)
            ext_std = np.asarray(ext_std, dtype=np.float32)
            ext_std = np.where(np.abs(ext_std) < 1e-6, 1.0, ext_std)
            clip = (clip - ext_mean) / ext_std
    return list(clip)
def _pad_or_trim_after_norm(frames: List[np.ndarray], target_T: Optional[int], mode: str = "center") -> List[np.ndarray]:
    """Force a (normalized) clip to exactly *target_T* frames.

    Args:
        frames: per-frame arrays that all share one shape, e.g. (H, W) or (H, W, C).
        target_T: desired frame count; None leaves the clip unchanged.
        mode: how to shorten a too-long clip. "center" drops frames evenly from
            both ends (an odd extra frame comes off the end); any other value
            keeps the first *target_T* frames.

    Returns:
        *frames* itself when no change is needed, otherwise a new list. A short
        clip is padded at the END with all-zero float32 frames, whatever *mode*
        is; after z-normalization zero corresponds to the mean intensity.

    Raises:
        ValueError: if *frames* is empty and padding would be required.
    """
    T = len(frames)
    if target_T is None or T == target_T:
        return frames
    if T > target_T:
        if mode == "center":
            left = (T - target_T) // 2
            return frames[left:left + target_T]
        return frames[:target_T]
    if T == 0:
        raise ValueError("No frames captured.")
    # Pad with frames shaped like the real ones so multi-channel clips work too.
    blank_shape = np.shape(frames[0])
    return frames + [np.zeros(blank_shape, dtype=np.float32) for _ in range(target_T - T)]
def _frames_to_tensor_no_norm(frames_gray: List[np.ndarray]) -> Optional[tf.Tensor]:
    """Pack frames into a float32 model input without normalizing them.

    For (H, W) frames the result has shape (1, T, H, W, 1): a batch axis is
    added in front and a channel axis at the end. Returns None when there are
    no frames.
    """
    if len(frames_gray) == 0:
        return None
    batch = np.stack(frames_gray, axis=0).astype(np.float32)[np.newaxis, ..., np.newaxis]
    return tf.convert_to_tensor(batch, dtype=tf.float32)
class InferenceRunner:
    """Turn a clip of grayscale frames into decoded, post-processed text.

    Uses the module-level ``model`` and ``num_to_char``. Each ``run`` call:
      1) normalizes the frames, unless the caller says they already are;
      2) center-trims or zero-pads them to EXPECTED_T;
      3) calls ``model.predict``;
      4) CTC beam-search decodes the result;
      5) applies ``postprocess_command`` to each decoded line.
    """

    def __init__(self):
        # NOTE(review): not read anywhere in this file; kept in case callers use it.
        self._busy = False

    @staticmethod
    def _parse_payload(payload):
        """Split *payload* into ``(frames, profile_key, already_normalized)``.

        Supported forms:
          1) frames_list
          2) (frames_list, profile_key: str), e.g. "1080x1920" / "other"
          3) (frames_list, already_normalized: bool), kept for backward compatibility
        A 2-tuple whose second item is an ndarray is a two-frame clip, not form 2/3.
        """
        if isinstance(payload, tuple) and len(payload) == 2 and not isinstance(payload[1], np.ndarray):
            frames, second = payload
            if isinstance(second, str):
                return frames, second, False
            return frames, "other", bool(second)
        return payload, "other", False

    @staticmethod
    def _ids_to_text(ids) -> str:
        """Map one row of CTC-decoded ids (right-padded with -1) to a string."""
        if isinstance(num_to_char, tf.keras.layers.StringLookup):
            ids = ids[ids != -1]
            return tf.strings.reduce_join(num_to_char(ids)).numpy().decode("utf-8")
        return "".join(num_to_char(int(x)) for x in ids if int(x) >= 0)

    def run(self, payload) -> str:
        """Decode one clip and return the post-processed text.

        Returns one line per batch row, joined with "\\n". Missing model or
        empty input yields a bracketed status string instead of raising.
        """
        if model is None or num_to_char is None:
            return "Model or num_to_char not set. Please plug them in."
        if payload is None:
            return "[No input]"

        frames, profile_key, already_norm = self._parse_payload(payload)
        if frames is None or len(frames) == 0:
            return "[No frames captured]"

        if already_norm:
            # The caller normalized already; normalizing again would distort the input.
            frames = [np.asarray(f, dtype=np.float32) for f in frames]
        else:
            frames = _normalize_with_profile(frames, profile_key)

        orig_T = len(frames)
        if EXPECTED_T is not None and orig_T != EXPECTED_T:
            frames = _pad_or_trim_after_norm(frames, EXPECTED_T, mode="center")
            print(f"[Inference] Adjusted T from {orig_T} -> {EXPECTED_T}")

        X = _frames_to_tensor_no_norm(frames)
        if X is None:
            return "[Tensor creation failed]"

        # verbose=0 suppresses Keras' progress bar on every single-clip call.
        yhat = model.predict(X, verbose=0)  # ctc_decode expects (batch, time, classes)

        # Decode only the real (non-padded) frames, never past the model's own
        # output length, which can be shorter if the model downsamples time.
        valid_len = orig_T if EXPECTED_T is None else min(orig_T, EXPECTED_T)
        valid_len = min(valid_len, int(yhat.shape[1]))
        decoded = tf.keras.backend.ctc_decode(
            yhat,
            input_length=[valid_len],  # batch size is 1
            greedy=False,
        )[0][0].numpy()

        return "\n".join(postprocess_command(self._ids_to_text(ids)) for ids in decoded)