import os
import random
from typing import List, Optional
import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
import cv2
import mediapipe as mp
import tensorflow as tf
from camera_worker import CameraWorker, _extract_lip_roi_gray, _choose_prescale_range, _sample_scale
from inference import InferenceRunner
from textgen import generate_text
from config import RECORD_SECONDS
def _qimage_from_gray_u8(arr_u8: np.ndarray) -> QtGui.QImage:
    """Build a QImage that owns its pixel data from a 2-D uint8 array.

    Args:
        arr_u8: 2-D ``uint8`` array of shape (rows, cols). It does not need to
            be C-contiguous.

    Returns:
        A ``Format_Grayscale8`` QImage holding a private copy of the pixels, so
        it stays valid after *arr_u8* is freed or modified.

    Raises:
        ValueError: if *arr_u8* is not a 2-D uint8 array.
    """
    # Explicit check instead of `assert`: asserts vanish under `python -O`, and
    # wrapping a wrong-dtype buffer would produce a corrupted image.
    if arr_u8.ndim != 2 or arr_u8.dtype != np.uint8:
        raise ValueError(
            f"expected a 2-D uint8 array, got ndim={arr_u8.ndim}, dtype={arr_u8.dtype}"
        )
    h, w = arr_u8.shape
    # The QImage constructor wraps the buffer without copying; a C-contiguous
    # layout makes bytesPerLine == w correct. The final .copy() detaches the
    # image from numpy-owned memory.
    arr_u8 = np.ascontiguousarray(arr_u8)
    qimg = QtGui.QImage(arr_u8.data, w, h, w, QtGui.QImage.Format_Grayscale8)
    return qimg.copy()
def _to_gray_u8_for_display(frame) -> Optional[np.ndarray]:
"""Convert possibly-float / TF / (H,W,1) frames to uint8 grayscale for UI preview."""
if frame is None:
return None
if isinstance(frame, tf.Tensor):
frame = frame.numpy()
arr = np.asarray(frame)
# (H,W,1) -> (H,W)
if arr.ndim == 3 and arr.shape[-1] == 1:
arr = np.squeeze(arr, axis=-1)
# (H,W,3) -> gray
if arr.ndim == 3 and arr.shape[-1] == 3:
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY)
if arr.ndim != 2:
return None
if arr.dtype != np.uint8:
x = arr.astype(np.float32)
# robust normalization for display
lo, hi = np.percentile(x, [1.0, 99.0])
if hi - lo < 1e-6:
lo, hi = float(x.min()), float(x.max())
if hi - lo < 1e-6:
return np.zeros_like(x, dtype=np.uint8)
x = np.clip((x - lo) / (hi - lo), 0.0, 1.0) * 255.0
arr = x.astype(np.uint8)
return arr
class InferenceWorker(QtCore.QObject):
    """QObject wrapper that runs an ``InferenceRunner`` from its own thread.

    Signals:
        result_ready(str): the runner's output, or ``"[Inference error] <msg>"``
            when the runner raised.
        status(str): short progress text for the status bar.
    """

    result_ready = QtCore.pyqtSignal(str)
    status = QtCore.pyqtSignal(str)

    def __init__(self, runner: InferenceRunner):
        super().__init__()
        self.runner = runner

    @QtCore.pyqtSlot(object)
    def run(self, frames_gray_list):
        """Invoke ``self.runner.run`` on the given payload and emit the outcome.

        The payload is forwarded untouched. Any exception raised by the runner
        is reported through ``result_ready`` as an error string instead of
        propagating out of the slot.
        """
        self.status.emit("Running inference…")
        try:
            decoded = self.runner.run(frames_gray_list)
        except Exception as err:
            self.result_ready.emit(f"[Inference error] {err}")
        else:
            self.result_ready.emit(decoded)
class VideoFileExtractor(QtCore.QObject):
    """Extracts grayscale lip-ROI frames from a video file off the GUI thread.

    Meant to be moved to a QThread and driven through the queued ``run`` slot.
    Every call to ``run`` ends with exactly one ``done`` emission. Its payload
    is always a ``(frames, profile)`` tuple, where ``frames`` is empty on
    failure.
    """

    # Emits (List[np.ndarray], str): grayscale lip frames and the normalization
    # profile name ("1080x1920" or "other").
    done = QtCore.pyqtSignal(object)
    status = QtCore.pyqtSignal(str)

    @QtCore.pyqtSlot(str)
    def run(self, path: str):
        """Parse the video at *path* and emit ``done`` with ``(frames, profile)``.

        Args:
            path: Filesystem path of the video to parse.
        """
        try:
            frames, profile = self._extract(path)
        except Exception as exc:
            # An exception escaping a PyQt5 slot aborts the process by default,
            # and ``done`` would never fire, leaving the UI waiting forever.
            import traceback

            traceback.print_exc()
            self.status.emit(f"Video parsing failed: {exc}")
            frames, profile = [], "other"
        self.done.emit((frames, profile))

    def _extract(self, path: str) -> tuple:
        """Open *path*, report progress and return ``(lip_frames, profile)``."""
        if not os.path.exists(path):
            self.status.emit(f"File not found: {path}")
            return [], "other"

        self.status.emit(f"Loading video: {path}")
        cap = cv2.VideoCapture(path)
        try:
            if not cap.isOpened():
                self.status.emit("Could not open video.")
                return [], "other"

            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
            profile = "1080x1920" if (h == 1920 and w == 1080) else "other"
            self.status.emit(f"Detected input resolution: {w}x{h} -> profile={profile}")

            # One prescale is sampled per video so every frame is cropped consistently.
            video_prescale = _sample_scale(_choose_prescale_range(w, h))
            self.status.emit(f"Video prescale fixed at {video_prescale:.3f}")

            fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
            # Downsample toward ~25 fps: ~100 fps -> every 4th frame, ~25 fps -> every frame.
            skip = max(1, int(round(fps / 25.0))) if fps > 0 else 1
            self.status.emit(f"Detected FPS={fps:.2f} -> reading every {skip} frame(s)")

            frames = self._read_lip_frames(cap, skip, video_prescale)
        finally:
            # Released on every path, including exceptions raised above.
            cap.release()

        self.status.emit(f"Video parsed. Lip frames: {len(frames)}")
        return frames, profile

    @staticmethod
    def _read_lip_frames(cap, skip: int, prescale: float) -> List[np.ndarray]:
        """Read every *skip*-th frame from *cap* and return the lip ROI crops found."""
        face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )
        frames: List[np.ndarray] = []
        index = 0
        try:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                keep = index % skip == 0
                index += 1
                if not keep:
                    continue
                res = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if not res.multi_face_landmarks:
                    continue
                lip_gray = _extract_lip_roi_gray(
                    frame, res.multi_face_landmarks[0], prescale=prescale
                )
                if lip_gray is not None:
                    frames.append(lip_gray)
                else:
                    print("Skipping None")
        finally:
            face_mesh.close()
        return frames
class MainWindow(QtWidgets.QMainWindow):
    """Top-level window of the lip-reading GUI.

    Layout (2x2 grid):
        row 0: camera preview   | REC badge, prompt, decoded text
        row 1: action buttons   | lip ROI preview

    The constructor starts three QThreads (camera, inference, video-file
    extraction). Their results come back through Qt signals connected to the
    slots below.
    """

    # Emits a (frames, profile) tuple; delivered to InferenceWorker.run on its thread.
    start_inference = QtCore.pyqtSignal(object)

    # Dark style shared by both action buttons. It is a class constant so the
    # record button can be restored to it after a recording ends.
    _BTN_STYLE = """
        QPushButton {
            background-color: #2e2e2e;
            color: white;
            font-weight: bold;
            font-size: 13px;
            padding: 10px;
            border-radius: 8px;
            border: 1px solid #555;
        }
        QPushButton:hover {
            background-color: #3a3a3a;
        }
        QPushButton:pressed {
            background-color: #1f1f1f;
        }
    """

    # Record button while recording. Font size and border width match
    # _BTN_STYLE so the button keeps its size when toggled.
    _REC_BTN_STYLE = """
        QPushButton {
            background: #b00020;
            color: white;
            font-weight: bold;
            font-size: 13px;
            padding: 10px;
            border-radius: 8px;
            border: 1px solid #6a0012;
        }
        QPushButton:disabled { background: #6a0012; }
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Lip-reading GUI")
        self.current_prompt = ""

        self._build_ui()
        self._start_workers()

        self.btn_record.clicked.connect(self.on_record_clicked)
        self.btn_infer_video.clicked.connect(self.on_infer_from_video_clicked)
        self.statusBar().showMessage("Ready")

        # Fit the window to its content on startup.
        self.adjustSize()
        self.resize(self.sizeHint())

    @staticmethod
    def _record_button_text() -> str:
        """Idle label of the record button, used initially and after each recording."""
        return f"RECORD FROM WEBCAM & INFER ({RECORD_SECONDS:.1f}s)"

    # --------------------------
    # Construction helpers
    # --------------------------
    def _build_ui(self):
        """Create every widget and arrange them on the central widget's grid."""
        # ---- Left column: camera preview (row 0) and action buttons (row 1) ----
        self.video_label = QtWidgets.QLabel()
        self.video_label.setMinimumSize(640, 360)
        self.video_label.setAlignment(QtCore.Qt.AlignCenter)
        self.video_label.setStyleSheet("background-color: #222; color: #eee;")

        self.btn_infer_video = QtWidgets.QPushButton("INFER FROM MP4…")
        self.btn_record = QtWidgets.QPushButton(self._record_button_text())
        self.btn_infer_video.setStyleSheet(self._BTN_STYLE)
        self.btn_record.setStyleSheet(self._BTN_STYLE)

        buttons_box = QtWidgets.QWidget()
        buttons_layout = QtWidgets.QVBoxLayout(buttons_box)
        buttons_layout.setContentsMargins(0, 0, 0, 0)
        buttons_layout.setSpacing(10)
        buttons_layout.addWidget(self.btn_infer_video)
        buttons_layout.addWidget(self.btn_record)

        # ---- Right column, row 0: REC badge, prompt and decoded text ----
        self.rec_badge = QtWidgets.QLabel("● REC")
        self.rec_badge.setAlignment(QtCore.Qt.AlignCenter)
        self.rec_badge.setFixedWidth(90)
        self.rec_badge.setStyleSheet("""
            QLabel {
                background: #b00020;
                color: white;
                font-weight: bold;
                padding: 6px 10px;
                border-radius: 10px;
            }
        """)
        self.rec_badge.hide()

        rec_row = QtWidgets.QHBoxLayout()
        rec_row.setContentsMargins(0, 0, 0, 0)
        rec_row.addWidget(self.rec_badge)
        rec_row.addStretch(1)

        label_style = "font-size: 14px; font-weight: 600;"

        self.prompt_title = QtWidgets.QLabel("Text to read (during recording)")
        self.prompt_title.setStyleSheet(label_style)
        self.prompt_label = QtWidgets.QLabel("—")
        self.prompt_label.setWordWrap(True)
        self.prompt_label.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop)
        self.prompt_label.setMinimumHeight(90)
        self.prompt_label.setStyleSheet("""
            QLabel {
                background: #111;
                border: 1px solid #444;
                padding: 10px;
                font-size: 20px;
                color: #eee;
                border-radius: 6px;
            }
        """)

        self.decoded_title = QtWidgets.QLabel("Decoded text")
        self.decoded_title.setStyleSheet(label_style)
        self.decoded_text = QtWidgets.QTextEdit()
        self.decoded_text.setReadOnly(True)
        font = self.decoded_text.font()
        # pointSize() is -1 for pixel-sized fonts; bumping that would give an
        # invalid size of 0, so only enlarge point-sized fonts.
        if font.pointSize() > 0:
            font.setPointSize(font.pointSize() + 1)
        self.decoded_text.setFont(font)

        right_top = QtWidgets.QWidget()
        right_top_layout = QtWidgets.QVBoxLayout(right_top)
        right_top_layout.setContentsMargins(0, 0, 0, 0)
        right_top_layout.setSpacing(10)
        right_top_layout.addLayout(rec_row)
        right_top_layout.addWidget(self.prompt_title)
        right_top_layout.addWidget(self.prompt_label)
        right_top_layout.addWidget(self.decoded_title)
        # Stretch 1: the decoded box absorbs spare height, so its bottom lines
        # up with the camera preview in the same grid row.
        right_top_layout.addWidget(self.decoded_text, 1)

        # ---- Right column, row 1: lip ROI preview centred under the decoded box ----
        self.lip_preview = QtWidgets.QLabel()
        self.lip_preview.setFixedSize(200, 100)
        self.lip_preview.setStyleSheet("background-color: black;")
        self.lip_preview.setAlignment(QtCore.Qt.AlignCenter)
        lip_title = QtWidgets.QLabel("Lip ROI preview")
        lip_title.setAlignment(QtCore.Qt.AlignCenter)
        lip_title.setStyleSheet(label_style)

        lip_row = QtWidgets.QHBoxLayout()
        lip_row.setContentsMargins(0, 0, 0, 0)
        lip_row.addStretch(1)
        lip_row.addWidget(self.lip_preview)
        lip_row.addStretch(1)

        right_bottom = QtWidgets.QWidget()
        right_bottom_layout = QtWidgets.QVBoxLayout(right_bottom)
        right_bottom_layout.setContentsMargins(0, 0, 0, 0)
        right_bottom_layout.setSpacing(6)
        right_bottom_layout.addWidget(lip_title)
        right_bottom_layout.addLayout(lip_row)

        # ---- Main grid ----
        # Same-row placement aligns the decoded box with the camera preview
        # (row 0) and the buttons with the lip preview (row 1).
        central = QtWidgets.QWidget()
        grid = QtWidgets.QGridLayout(central)
        grid.setContentsMargins(12, 12, 12, 12)
        grid.setHorizontalSpacing(18)
        grid.setVerticalSpacing(12)
        grid.addWidget(self.video_label, 0, 0)
        grid.addWidget(right_top, 0, 1)
        grid.addWidget(buttons_box, 1, 0, alignment=QtCore.Qt.AlignBottom)
        grid.addWidget(right_bottom, 1, 1, alignment=QtCore.Qt.AlignBottom)
        # Top row takes the available height; bottom row stays compact.
        grid.setRowStretch(0, 1)
        grid.setRowStretch(1, 0)
        # Right column expands.
        grid.setColumnStretch(0, 0)
        grid.setColumnStretch(1, 1)
        self.setCentralWidget(central)

    def _start_workers(self):
        """Create the camera, inference and video-file workers and start their threads."""
        # Camera
        self.cam_thread = QtCore.QThread()
        self.cam_worker = CameraWorker()
        self.cam_worker.moveToThread(self.cam_thread)
        self.cam_worker.frame_ready.connect(self.update_frame)
        self.cam_worker.recording_done.connect(self.on_recording_done)
        self.cam_worker.status.connect(self.set_status)
        # recording_state_changed is optional on CameraWorker. Remember whether
        # it was actually connected so on_record_clicked knows whether it must
        # update the UI itself.
        self._rec_state_signal_connected = False
        if hasattr(self.cam_worker, "recording_state_changed"):
            try:
                self.cam_worker.recording_state_changed.connect(
                    self.on_recording_state_changed
                )
                self._rec_state_signal_connected = True
            except (AttributeError, TypeError):
                pass
        self.cam_thread.start()

        # Inference
        self.infer_runner = InferenceRunner()
        self.infer_thread = QtCore.QThread()
        self.infer_worker = InferenceWorker(self.infer_runner)
        self.infer_worker.moveToThread(self.infer_thread)
        self.start_inference.connect(self.infer_worker.run, QtCore.Qt.QueuedConnection)
        self.infer_worker.result_ready.connect(self.on_inference_result)
        self.infer_worker.status.connect(self.set_status)
        self.infer_thread.start()

        # Video-file extraction
        self.video_thread = QtCore.QThread()
        self.video_worker = VideoFileExtractor()
        self.video_worker.moveToThread(self.video_thread)
        self.video_worker.done.connect(self._on_video_extracted)
        self.video_worker.status.connect(self.set_status)
        self.video_thread.start()

    # --------------------------
    # UI helpers
    # --------------------------
    def _set_record_button_style(self, is_recording: bool):
        """Switch the record button between its recording and idle look."""
        if is_recording:
            self.btn_record.setText("Recording…")
            self.btn_record.setStyleSheet(self._REC_BTN_STYLE)
        else:
            # Restore exactly the initial label and dark style. Clearing the
            # stylesheet would fall back to the native look.
            self.btn_record.setText(self._record_button_text())
            self.btn_record.setStyleSheet(self._BTN_STYLE)

    def _set_recording_ui(self, is_recording: bool):
        """Toggle the REC badge, the record-button look and button availability."""
        self.rec_badge.setVisible(is_recording)
        self._set_record_button_style(is_recording)
        # Prevent switching modes mid-recording (simpler UX).
        self.btn_infer_video.setEnabled(not is_recording)
        self.btn_record.setEnabled(not is_recording)

    def _show_lip_preview(self, frames):
        """Show one randomly chosen frame from the non-empty *frames* in the lip preview."""
        img_u8 = _to_gray_u8_for_display(random.choice(frames))
        if img_u8 is None:
            return
        pix = QtGui.QPixmap.fromImage(_qimage_from_gray_u8(img_u8)).scaled(
            self.lip_preview.width(),
            self.lip_preview.height(),
            QtCore.Qt.KeepAspectRatio,
            QtCore.Qt.SmoothTransformation,
        )
        self.lip_preview.setPixmap(pix)

    # --------------------------
    # Slots
    # --------------------------
    @QtCore.pyqtSlot(QtGui.QImage)
    def update_frame(self, qimg):
        """Show a camera frame scaled to the preview label, keeping aspect ratio."""
        pix = QtGui.QPixmap.fromImage(qimg)
        self.video_label.setPixmap(
            pix.scaled(
                self.video_label.width(),
                self.video_label.height(),
                QtCore.Qt.KeepAspectRatio,
                QtCore.Qt.SmoothTransformation,
            )
        )

    @QtCore.pyqtSlot()
    def on_record_clicked(self):
        """Show a fresh prompt, clear old output and ask the camera worker to record."""
        # Show the prompt first so the user can start reading right away.
        self.current_prompt = generate_text()
        self.prompt_label.setText(self.current_prompt)
        # Clear previous decoded output for clarity.
        self.decoded_text.clear()
        # Without a connected recording_state_changed signal, nothing else will
        # switch the UI into recording mode.
        if not self._rec_state_signal_connected:
            self._set_recording_ui(True)
        # NOTE(review): cam_worker lives on cam_thread but is called directly
        # from the GUI thread here; confirm start_recording is safe cross-thread.
        self.cam_worker.start_recording()

    @QtCore.pyqtSlot(bool)
    def on_recording_state_changed(self, is_recording: bool):
        self._set_recording_ui(is_recording)

    @QtCore.pyqtSlot(object)
    def on_recording_done(self, payload):
        """Handle ``(frames_gray_list, (w, h))`` from the camera worker and start inference."""
        frames_gray_list, (w, h) = payload
        # Leave recording mode even if the camera worker has no state signal.
        self._set_recording_ui(False)

        if not frames_gray_list:
            self.set_status("No lip ROI frames captured from webcam.")
            return

        # Normalization profile is chosen from the ORIGINAL source resolution.
        profile = "1080x1920" if (w == 1080 and h == 1920) else "other"
        self._show_lip_preview(frames_gray_list)
        self.prompt_label.setText("")
        self.start_inference.emit((frames_gray_list, profile))

    @QtCore.pyqtSlot()
    def on_infer_from_video_clicked(self):
        """Ask for a video file and queue it on the video-extraction worker."""
        path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self,
            "Choose a video file",
            "",
            "Video files (*.mp4 *.mov *.avi *.mkv);;All files (*.*)",
        )
        if not path:
            return
        self.decoded_text.clear()
        # Queued so that VideoFileExtractor.run executes on video_thread.
        QtCore.QMetaObject.invokeMethod(
            self.video_worker,
            "run",
            QtCore.Qt.QueuedConnection,
            QtCore.Q_ARG(str, path),
        )
        self.set_status(f"Parsing video… {os.path.basename(path)}")

    @QtCore.pyqtSlot(object)
    def _on_video_extracted(self, payload):
        """Handle the extractor's ``(lip_frames, profile)`` payload and start inference."""
        # Accept a bare frame list as well as the tuple, so a malformed payload
        # cannot crash this slot on unpacking.
        if isinstance(payload, tuple):
            lip_frames, profile = payload
        else:
            lip_frames, profile = payload, "other"

        if not lip_frames:
            self.set_status("No lip ROI frames extracted from video.")
            return

        self._show_lip_preview(lip_frames)
        self.start_inference.emit((lip_frames, profile))

    @QtCore.pyqtSlot(str)
    def on_inference_result(self, text):
        self.decoded_text.setPlainText(text)

    def set_status(self, msg: str):
        """Show *msg* in the status bar for 6 seconds."""
        self.statusBar().showMessage(msg, 6000)

    def closeEvent(self, event):
        """Stop the camera worker and shut down all worker threads (best effort)."""
        try:
            self.cam_worker.close()
        except Exception:
            pass
        # Each thread gets up to 1 s. NOTE(review): a thread busy inside a long
        # slot (video parsing, inference) only honours quit() after that slot returns.
        for thread in (self.cam_thread, self.infer_thread, self.video_thread):
            try:
                thread.quit()
                thread.wait(1000)
            except Exception:
                pass
        super().closeEvent(event)