from paddleocr import PaddleOCR import torch import numpy as np from Levenshtein import distance from typing import List, Union from PIL import Image class OcrScorer: def __init__(self, use_gpu: bool = False): """ OCR reward calculator :param use_gpu: Whether to use GPU acceleration for PaddleOCR """ self.ocr = PaddleOCR( use_angle_cls=False, lang="en", use_gpu=use_gpu, show_log=False # Disable unnecessary log output ) @torch.no_grad() def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> torch.Tensor: """ Calculate OCR reward :param images: List of input images (PIL or numpy format) :param prompts: Corresponding target text list :return: Reward tensor (CPU) """ prompts = [prompt.split('"')[1] for prompt in prompts] rewards = [] # Ensure input lengths are consistent assert len(images) == len(prompts), "Images and prompts must have the same length" for img, prompt in zip(images, prompts): # Convert image format if isinstance(img, Image.Image): img = np.array(img) try: # OCR recognition result = self.ocr.ocr(img, cls=False) # Extract recognized text (handle possible multi-line results) recognized_text = ( "".join([res[1][0] if res[1][1] > 0 else "" for res in result[0]]) if result[0] else "" ) recognized_text = recognized_text.replace(" ", "").lower() prompt = prompt.replace(" ", "").lower() if prompt in recognized_text: dist = 0 else: dist = distance(recognized_text, prompt) # Recognized many unrelated characters, only add one character penalty if dist > len(prompt): dist = len(prompt) except Exception as e: # Error handling (e.g., OCR parsing failure) print(f"OCR processing failed: {str(e)}") dist = len(prompt) # Maximum penalty reward = 1 - dist / (len(prompt)) rewards.append(reward) return rewards if __name__ == "__main__": example_image_path = "test_cases/hello world.jpg" example_image = Image.open(example_image_path) example_prompt = 'New York Skyline with "Hello World" written with fireworks on the sky' # Instantiate scorer scorer = OcrScorer(use_gpu=False) # Call scorer and print result reward = scorer([example_image], [example_prompt]) print(f"OCR Reward: {reward}")