# AI-SPEAK / LipReadingApp / main_window.py
import os
import random
from typing import List, Optional

import numpy as np
from PyQt5 import QtCore, QtGui, QtWidgets
import cv2
import mediapipe as mp
import tensorflow as tf

from camera_worker import CameraWorker, _extract_lip_roi_gray, _choose_prescale_range, _sample_scale
from inference import InferenceRunner
from textgen import generate_text
from config import RECORD_SECONDS


def _qimage_from_gray_u8(arr_u8: np.ndarray) -> QtGui.QImage:
    """Build a self-contained 8-bit grayscale QImage from a 2D uint8 array.

    The array is made C-contiguous so one row occupies exactly ``width``
    bytes, which is the stride handed to the QImage constructor. The
    returned image is a copy of the one constructed over the array's buffer.
    """
    assert arr_u8.ndim == 2 and arr_u8.dtype == np.uint8
    packed = np.ascontiguousarray(arr_u8)
    height, width = packed.shape
    bytes_per_line = width  # one byte per pixel, rows are tightly packed
    wrapped = QtGui.QImage(
        packed.data,
        width,
        height,
        bytes_per_line,
        QtGui.QImage.Format_Grayscale8,
    )
    # Return a copy rather than the image constructed over ``packed``'s buffer.
    return wrapped.copy()


def _to_gray_u8_for_display(frame) -> Optional[np.ndarray]:
    """Convert possibly-float / TF / (H,W,1) frames to uint8 grayscale for UI preview."""
    if frame is None:
        return None

    if isinstance(frame, tf.Tensor):
        frame = frame.numpy()

    arr = np.asarray(frame)

    # (H,W,1) -> (H,W)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = np.squeeze(arr, axis=-1)

    # (H,W,3) -> gray
    if arr.ndim == 3 and arr.shape[-1] == 3:
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY)

    if arr.ndim != 2:
        return None

    if arr.dtype != np.uint8:
        x = arr.astype(np.float32)

        # robust normalization for display
        lo, hi = np.percentile(x, [1.0, 99.0])
        if hi - lo < 1e-6:
            lo, hi = float(x.min()), float(x.max())
        if hi - lo < 1e-6:
            return np.zeros_like(x, dtype=np.uint8)

        x = np.clip((x - lo) / (hi - lo), 0.0, 1.0) * 255.0
        arr = x.astype(np.uint8)

    return arr


class InferenceWorker(QtCore.QObject):
    """QObject wrapper that runs an ``InferenceRunner`` from a Qt slot.

    Signals:
        result_ready(str): the value returned by the runner, or an
            ``[Inference error] ...`` message if the runner raised.
        status(str): progress text for the status bar.
    """

    result_ready = QtCore.pyqtSignal(str)
    status = QtCore.pyqtSignal(str)

    def __init__(self, runner: InferenceRunner):
        super().__init__()
        self.runner = runner

    def _decode(self, payload) -> str:
        """Call the runner, turning any exception into an error string."""
        try:
            return self.runner.run(payload)
        except Exception as exc:
            return f"[Inference error] {exc}"

    @QtCore.pyqtSlot(object)
    def run(self, frames_gray_list):
        """Run inference on the payload and publish the outcome.

        ``frames_gray_list`` is forwarded unchanged to ``runner.run``. Within
        this file ``MainWindow`` emits it as a ``(frames, profile)`` tuple.
        """
        self.status.emit("Running inference…")
        self.result_ready.emit(self._decode(frames_gray_list))


class VideoFileExtractor(QtCore.QObject):
    """Worker that pulls grayscale lip-ROI frames out of a video file.

    Signals:
        done(object): always a ``(frames, profile)`` tuple, where ``frames``
            is a ``List[np.ndarray]`` of lip crops (empty on any failure) and
            ``profile`` is ``"1080x1920"`` or ``"other"``.
        status(str): human-readable progress / error text.
    """

    done = QtCore.pyqtSignal(object)   # emits (List[np.ndarray], profile: str)
    status = QtCore.pyqtSignal(str)

    @QtCore.pyqtSlot(str)
    def run(self, path: str):
        """Extract lip frames from ``path`` and emit them through ``done``.

        Every exit path emits ``done`` exactly once with a 2-tuple so that
        consumers can unpack the payload unconditionally. Failures (missing
        file, unreadable video, an exception while decoding) are reported via
        ``status`` and produce an empty frame list with profile ``"other"``.
        """
        if not os.path.exists(path):
            self.status.emit(f"File not found: {path}")
            self.done.emit(([], "other"))
            return

        self.status.emit(f"Loading video: {path}")
        cap = cv2.VideoCapture(path)
        try:
            if not cap.isOpened():
                self.status.emit("Could not open video.")
                self.done.emit(([], "other"))
                return
            frames, profile = self._read_lip_frames(cap)
        except Exception as exc:
            # An exception escaping a slot would leave the UI waiting for a
            # result that never arrives; report it and hand back an empty one.
            self.status.emit(f"[Video error] {exc}")
            self.done.emit(([], "other"))
            return
        finally:
            # Release the capture on every path, including early returns.
            cap.release()

        self.status.emit(f"Video parsed. Lip frames: {len(frames)}")
        self.done.emit((frames, profile))

    def _read_lip_frames(self, cap):
        """Decode ``cap`` and return ``(lip_frames, profile)``.

        ``cap`` must be an opened ``cv2.VideoCapture``; it is not released
        here. The face-mesh detector created by this method is always closed.
        """
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        profile = "1080x1920" if (h == 1920 and w == 1080) else "other"
        self.status.emit(f"Detected input resolution: {w}x{h} -> profile={profile}")

        # One prescale value is drawn per video and reused for every frame.
        scale_range = _choose_prescale_range(w, h)
        video_prescale = _sample_scale(scale_range)
        self.status.emit(f"Video prescale fixed at {video_prescale:.3f}")

        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        # If ~100fps -> skip=4; if ~25fps -> skip=1; generalize by rounding fps/25
        skip = max(1, int(round(fps / 25.0))) if fps > 0 else 1
        self.status.emit(f"Detected FPS={fps:.2f} -> reading every {skip} frame(s)")

        face_mesh = mp.solutions.face_mesh.FaceMesh(
            static_image_mode=False,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

        frames: List[np.ndarray] = []
        try:
            index = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                # Keep frame 0, skip, 2*skip, ...; drop everything in between.
                keep = (index % skip) == 0
                index += 1
                if not keep:
                    continue

                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                res = face_mesh.process(rgb)
                if not res.multi_face_landmarks:
                    # No face in this frame: contributes nothing.
                    continue

                landmarks = res.multi_face_landmarks[0]
                lip_gray = _extract_lip_roi_gray(frame, landmarks, prescale=video_prescale)
                if lip_gray is not None:
                    frames.append(lip_gray)
                else:
                    print("Skipping None")
        finally:
            face_mesh.close()

        return frames, profile


class MainWindow(QtWidgets.QMainWindow):
    """Main window of the lip-reading GUI.

    Layout (2x2 grid):
        Row 0: camera preview (left) | REC badge, prompt and decoded text (right)
        Row 1: action buttons (left) | lip ROI preview (right)

    Three worker objects are moved to their own ``QThread``: the camera
    worker, the inference worker and the video-file extractor. Frames are
    handed to the inference worker through ``start_inference`` as a
    ``(frames, profile)`` tuple.
    """

    start_inference = QtCore.pyqtSignal(object)

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Lip-reading GUI")

        self.current_prompt = ""

        self._build_ui()
        self._start_workers()

        # ==========================
        # Actions
        # ==========================
        self.btn_record.clicked.connect(self.on_record_clicked)
        self.btn_infer_video.clicked.connect(self.on_infer_from_video_clicked)

        self.statusBar().showMessage("Ready")

        # Wrap content better on startup
        self.adjustSize()
        self.resize(self.sizeHint())

    # --------------------------
    # Construction helpers
    # --------------------------
    def _build_ui(self):
        """Create all widgets and install the central grid layout."""
        # =========================================
        # Left: camera preview (top) + buttons (bottom)
        # =========================================
        self.video_label = QtWidgets.QLabel()
        self.video_label.setMinimumSize(640, 360)
        self.video_label.setAlignment(QtCore.Qt.AlignCenter)
        self.video_label.setStyleSheet("background-color: #222; color: #eee;")

        # Idle caption and style are remembered so they can be restored
        # verbatim once a recording finishes.
        self._record_idle_text = f"RECORD FROM WEBCAM & INFER ({RECORD_SECONDS:.1f}s)"
        self.btn_infer_video = QtWidgets.QPushButton("INFER FROM MP4…")
        self.btn_record = QtWidgets.QPushButton(self._record_idle_text)

        btn_style = """
        QPushButton {
            background-color: #2e2e2e;
            color: white;
            font-weight: bold;
            font-size: 13px;
            padding: 10px;
            border-radius: 8px;
            border: 1px solid #555;
        }

        QPushButton:hover {
            background-color: #3a3a3a;
        }

        QPushButton:pressed {
            background-color: #1f1f1f;
        }
        """
        self._button_style = btn_style
        self.btn_infer_video.setStyleSheet(btn_style)
        self.btn_record.setStyleSheet(btn_style)

        # Buttons container (will be placed in the bottom-left grid cell)
        buttons_box = QtWidgets.QWidget()
        buttons_layout = QtWidgets.QVBoxLayout(buttons_box)
        buttons_layout.setContentsMargins(0, 0, 0, 0)
        buttons_layout.setSpacing(10)
        buttons_layout.addWidget(self.btn_infer_video)
        buttons_layout.addWidget(self.btn_record)

        # =========================================
        # Right (top): REC + prompt + decoded (bottom aligned with camera preview)
        # =========================================
        self.rec_badge = QtWidgets.QLabel("● REC")
        self.rec_badge.setAlignment(QtCore.Qt.AlignCenter)
        self.rec_badge.setFixedWidth(90)
        self.rec_badge.setStyleSheet("""
            QLabel {
                background: #b00020;
                color: white;
                font-weight: bold;
                padding: 6px 10px;
                border-radius: 10px;
            }
        """)
        self.rec_badge.hide()

        rec_row = QtWidgets.QHBoxLayout()
        rec_row.setContentsMargins(0, 0, 0, 0)
        rec_row.addWidget(self.rec_badge)
        rec_row.addStretch(1)

        self.prompt_title = QtWidgets.QLabel("Text to read (during recording)")
        self.prompt_label = QtWidgets.QLabel("")
        self.prompt_label.setWordWrap(True)
        self.prompt_label.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop)
        self.prompt_label.setMinimumHeight(90)
        self.prompt_label.setStyleSheet("""
            QLabel {
                background: #111;
                border: 1px solid #444;
                padding: 10px;
                font-size: 20px;
                color: #eee;
                border-radius: 6px;
            }
        """)

        self.decoded_text = QtWidgets.QTextEdit()
        self.decoded_text.setReadOnly(True)
        font = self.decoded_text.font()
        font.setPointSize(font.pointSize() + 1)  # or setPointSize(12/13/14) explicitly
        self.decoded_text.setFont(font)

        right_top = QtWidgets.QWidget()
        right_top_layout = QtWidgets.QVBoxLayout(right_top)
        right_top_layout.setContentsMargins(0, 0, 0, 0)
        right_top_layout.setSpacing(10)
        right_top_layout.addLayout(rec_row)
        right_top_layout.addWidget(self.prompt_title)
        right_top_layout.addWidget(self.prompt_label)

        self.decoded_title = QtWidgets.QLabel("Decoded text")
        right_top_layout.addWidget(self.decoded_title)
        right_top_layout.addWidget(self.decoded_text, 1)  # EXPANDS; bottom aligns with camera preview

        # =========================================
        # Right (bottom): lip ROI preview centered under decoded box
        # =========================================
        self.lip_preview = QtWidgets.QLabel()
        self.lip_preview.setFixedSize(200, 100)
        self.lip_preview.setStyleSheet("background-color: black;")
        self.lip_preview.setAlignment(QtCore.Qt.AlignCenter)

        lip_title = QtWidgets.QLabel("Lip ROI preview")
        lip_title.setAlignment(QtCore.Qt.AlignCenter)

        label_style = "font-size: 14px; font-weight: 600;"
        self.prompt_title.setStyleSheet(label_style)
        self.decoded_title.setStyleSheet(label_style)
        lip_title.setStyleSheet(label_style)

        lip_row = QtWidgets.QHBoxLayout()
        lip_row.setContentsMargins(0, 0, 0, 0)
        lip_row.addStretch(1)
        lip_row.addWidget(self.lip_preview)
        lip_row.addStretch(1)

        right_bottom = QtWidgets.QWidget()
        right_bottom_layout = QtWidgets.QVBoxLayout(right_bottom)
        right_bottom_layout.setContentsMargins(0, 0, 0, 0)
        right_bottom_layout.setSpacing(6)
        right_bottom_layout.addWidget(lip_title)
        right_bottom_layout.addLayout(lip_row)

        # =========================================
        # Main grid layout:
        # Row 0: camera preview (L) | prompt+decoded (R)
        # Row 1: buttons (L)       | lip preview (R)
        #
        # This guarantees:
        # - decoded box bottom aligns with camera preview bottom (same row 0)
        # - buttons bottom aligns with lip preview bottom (same row 1)
        # - lip ROI preview centered under decoded box (right column)
        # =========================================
        central = QtWidgets.QWidget()
        grid = QtWidgets.QGridLayout(central)
        grid.setContentsMargins(12, 12, 12, 12)
        grid.setHorizontalSpacing(18)
        grid.setVerticalSpacing(12)

        grid.addWidget(self.video_label, 0, 0)
        grid.addWidget(right_top, 0, 1)

        # Bottom row widgets aligned to the bottom of the row
        grid.addWidget(buttons_box, 1, 0, alignment=QtCore.Qt.AlignBottom)
        grid.addWidget(right_bottom, 1, 1, alignment=QtCore.Qt.AlignBottom)

        # Make the TOP row take the available height; bottom row stays compact
        grid.setRowStretch(0, 1)
        grid.setRowStretch(1, 0)

        # Right column expands
        grid.setColumnStretch(0, 0)
        grid.setColumnStretch(1, 1)

        self.setCentralWidget(central)

    def _start_workers(self):
        """Create the camera, inference and video workers and start their threads."""
        self.cam_thread = QtCore.QThread()
        self.cam_worker = CameraWorker()
        self.cam_worker.moveToThread(self.cam_thread)
        self.cam_worker.frame_ready.connect(self.update_frame)
        self.cam_worker.recording_done.connect(self.on_recording_done)
        self.cam_worker.status.connect(self.set_status)

        # recording_state_changed is optional on the camera worker; connect it
        # when present and carry on without it otherwise.
        if hasattr(self.cam_worker, "recording_state_changed"):
            try:
                self.cam_worker.recording_state_changed.connect(self.on_recording_state_changed)
            except Exception:
                pass

        self.cam_thread.start()

        self.infer_runner = InferenceRunner()
        self.infer_thread = QtCore.QThread()
        self.infer_worker = InferenceWorker(self.infer_runner)
        self.infer_worker.moveToThread(self.infer_thread)
        self.start_inference.connect(self.infer_worker.run, QtCore.Qt.QueuedConnection)
        self.infer_worker.result_ready.connect(self.on_inference_result)
        self.infer_worker.status.connect(self.set_status)
        self.infer_thread.start()

        self.video_thread = QtCore.QThread()
        self.video_worker = VideoFileExtractor()
        self.video_worker.moveToThread(self.video_thread)
        self.video_worker.done.connect(self._on_video_extracted)
        self.video_worker.status.connect(self.set_status)
        self.video_thread.start()

    # --------------------------
    # UI helpers
    # --------------------------
    def _set_record_button_style(self, is_recording: bool):
        """Switch the record button between its red 'recording' look and its idle look."""
        if is_recording:
            self.btn_record.setText("Recording…")
            self.btn_record.setStyleSheet("""
                QPushButton {
                    background: #b00020;
                    color: white;
                    font-weight: bold;
                    padding: 10px;
                    border-radius: 8px;
                }
                QPushButton:disabled { background: #6a0012; }
            """)
        else:
            # Restore exactly the caption and stylesheet the button was built
            # with, so it keeps matching the other action button.
            self.btn_record.setText(self._record_idle_text)
            self.btn_record.setStyleSheet(self._button_style)

    def _set_recording_ui(self, is_recording: bool):
        """Show/hide the REC badge and lock the action buttons while recording."""
        self.rec_badge.setVisible(is_recording)
        self._set_record_button_style(is_recording)

        # Prevent switching modes mid-recording (simpler UX)
        self.btn_infer_video.setEnabled(not is_recording)
        self.btn_record.setEnabled(not is_recording)

    def _show_lip_preview(self, frames):
        """Display one randomly chosen frame of ``frames`` in the lip preview.

        ``frames`` must be non-empty. Nothing is shown when the chosen frame
        cannot be converted to a 2D uint8 image.
        """
        img_u8 = _to_gray_u8_for_display(random.choice(frames))
        if img_u8 is None:
            return
        pix = QtGui.QPixmap.fromImage(_qimage_from_gray_u8(img_u8)).scaled(
            self.lip_preview.width(),
            self.lip_preview.height(),
            QtCore.Qt.KeepAspectRatio,
            QtCore.Qt.SmoothTransformation,
        )
        self.lip_preview.setPixmap(pix)

    # --------------------------
    # Slots
    # --------------------------
    @QtCore.pyqtSlot(QtGui.QImage)
    def update_frame(self, qimg):
        """Show a camera frame, scaled to the preview label with aspect ratio kept."""
        pix = QtGui.QPixmap.fromImage(qimg)
        self.video_label.setPixmap(
            pix.scaled(
                self.video_label.width(),
                self.video_label.height(),
                QtCore.Qt.KeepAspectRatio,
                QtCore.Qt.SmoothTransformation,
            )
        )

    @QtCore.pyqtSlot()
    def on_record_clicked(self):
        """Show a fresh prompt, clear old output and ask the camera worker to record."""
        # Generate and show the prompt immediately so the user can start reading right away
        self.current_prompt = generate_text()
        self.prompt_label.setText(self.current_prompt)

        # Clear previous decoded output for clarity
        self.decoded_text.clear()

        # If camera worker doesn't emit recording_state_changed, still update UI immediately
        if not hasattr(self.cam_worker, "recording_state_changed"):
            self._set_recording_ui(True)

        # NOTE(review): this is a direct call from the GUI thread on an object
        # that lives in cam_thread — presumably start_recording is written to
        # be safe for that; confirm against camera_worker.py.
        self.cam_worker.start_recording()

    @QtCore.pyqtSlot(bool)
    def on_recording_state_changed(self, is_recording: bool):
        """Mirror the camera worker's recording state in the UI."""
        self._set_recording_ui(is_recording)

    @QtCore.pyqtSlot(object)
    def on_recording_done(self, payload):
        """Handle a finished webcam recording and start inference on it.

        ``payload`` is ``(frames_gray_list, (w, h))`` where ``(w, h)`` is the
        original source resolution.
        """
        # Leave recording mode first so the UI recovers even if the payload
        # turns out to be malformed (and even if no state signal exists).
        self._set_recording_ui(False)
        # The prompt is only meant to be visible while recording.
        self.prompt_label.setText("")

        frames_gray_list, (w, h) = payload

        if not frames_gray_list:
            self.set_status("No lip ROI frames captured from webcam.")
            return

        # Choose normalization profile by ORIGINAL source resolution
        profile = "1080x1920" if (w == 1080 and h == 1920) else "other"

        self._show_lip_preview(frames_gray_list)

        # Pass (frames, profile) into inference
        self.start_inference.emit((frames_gray_list, profile))

    @QtCore.pyqtSlot()
    def on_infer_from_video_clicked(self):
        """Let the user pick a video file and queue it on the video worker."""
        path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self,
            "Choose a video file",
            "",
            "Video files (*.mp4 *.mov *.avi *.mkv);;All files (*.*)"
        )
        if not path:
            return

        self.decoded_text.clear()

        # Queued invocation: the slot is requested with a queued connection
        # on the worker object, not called inline from this handler.
        QtCore.QMetaObject.invokeMethod(
            self.video_worker,
            "run",
            QtCore.Qt.QueuedConnection,
            QtCore.Q_ARG(str, path),
        )
        self.set_status(f"Parsing video… {os.path.basename(path)}")

    @QtCore.pyqtSlot(object)
    def _on_video_extracted(self, payload):
        """Handle frames extracted from a video file and start inference on them.

        ``payload`` is normally ``(lip_frames, profile)``. A bare frame list
        (as emitted on the extractor's failure paths) is also accepted and
        treated as profile ``"other"`` instead of raising on unpacking.
        """
        if isinstance(payload, tuple) and len(payload) == 2:
            lip_frames, profile = payload
        else:
            lip_frames, profile = payload, "other"

        if not lip_frames:
            self.set_status("No lip ROI frames extracted from video.")
            return

        self._show_lip_preview(lip_frames)

        # Pass (frames, profile) into inference
        self.start_inference.emit((lip_frames, profile))

    @QtCore.pyqtSlot(str)
    def on_inference_result(self, text):
        """Display the decoded text returned by the inference worker."""
        self.decoded_text.setPlainText(text)

    def set_status(self, msg: str):
        """Show ``msg`` in the status bar for six seconds."""
        self.statusBar().showMessage(msg, 6000)

    def closeEvent(self, event):
        """Stop the camera worker and shut down all worker threads before closing.

        Shutdown is deliberately best-effort: a failure in one step must not
        prevent the remaining threads from being stopped or the window from
        closing. Each thread is given up to one second to finish.
        """
        try:
            self.cam_worker.close()
        except Exception:
            pass

        for thread_name in ("cam_thread", "infer_thread", "video_thread"):
            try:
                thread = getattr(self, thread_name)
                thread.quit()
                thread.wait(1000)
            except Exception:
                pass

        super().closeEvent(event)