from PIL import Image
import io
import numpy as np
import torch
from collections import defaultdict


def jpeg_incompressibility():
    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in images]
        for image, buffer in zip(images, buffers):
            image.save(buffer, format="JPEG", quality=95)
        # Reward is the encoded JPEG size in kilobytes.
        sizes = [buffer.tell() / 1000 for buffer in buffers]
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    jpeg_fn = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        rew, meta = jpeg_fn(images, prompts, metadata)
        # Negate and rescale so that more compressible images score higher.
        return -rew / 500, meta

    return _fn


def aesthetic_score(device):
    from flow_grpo.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        else:
            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            images = torch.tensor(images, dtype=torch.uint8)
        scores = scorer(images)
        return scores, {}

    return _fn


def clip_score(device):
    from flow_grpo.clip_scorer import ClipScorer

    scorer = ClipScorer(device=device)

    def _fn(images, prompts, metadata):
        if not isinstance(images, torch.Tensor):
            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            images = torch.tensor(images, dtype=torch.uint8) / 255.0
        scores = scorer(images, prompts)
        return scores, {}

    return _fn


def hpsv2_score(device):
    from flow_grpo.hpsv2_scorer import HPSv2Scorer

    scorer = HPSv2Scorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if not isinstance(images, torch.Tensor):
            images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            images = torch.tensor(images, dtype=torch.uint8) / 255.0
        scores = scorer(images, prompts)
        return scores, {}

    return _fn


def pickscore_score(device):
    from flow_grpo.pickscore_scorer import PickScoreScorer

    scorer = PickScoreScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        scores = scorer(prompts, images)
        return scores, {}

    return _fn


def imagereward_score(device):
    from flow_grpo.imagereward_scorer import ImageRewardScorer

    scorer = ImageRewardScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        prompts = list(prompts)
        scores = scorer(prompts, images)
        return scores, {}

    return _fn
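# All reward factories above follow one contract: the factory returns a closure
# `_fn(images, prompts, metadata) -> (scores, meta)`, where `images` is either a
# float NCHW tensor in [0, 1] or a uint8 NHWC numpy array, and `scores` holds
# one value per image. A minimal usage sketch (output values illustrative only):
#
#     reward_fn = jpeg_compressibility()
#     scores, _ = reward_fn(torch.rand(2, 3, 256, 256), prompts=None, metadata=None)
#     # scores is a length-2 array; values nearer 0 mean more compressible images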
def geneval_score(device):
    from flow_grpo.gen_eval import load_geneval

    batch_size = 64
    compute_geneval = load_geneval(device)

    def _fn(images, prompts, metadatas, only_strict):
        # GenEval evaluates against structured metadata, not raw prompts.
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images_batched = np.array_split(images, int(np.ceil(len(images) / batch_size)))
        metadatas_batched = np.array_split(metadatas, int(np.ceil(len(metadatas) / batch_size)))
        all_scores = []
        all_rewards = []
        all_strict_rewards = []
        all_group_strict_rewards = []
        all_group_rewards = []
        for image_batch, metadata_batched in zip(images_batched, metadatas_batched):
            pil_images = [Image.fromarray(image) for image in image_batch]
            data = {
                "images": pil_images,
                "metadatas": list(metadata_batched),
                "only_strict": only_strict,
            }
            scores, rewards, strict_rewards, group_rewards, group_strict_rewards = compute_geneval(**data)
            all_scores += scores
            all_rewards += rewards
            all_strict_rewards += strict_rewards
            all_group_strict_rewards.append(group_strict_rewards)
            all_group_rewards.append(group_rewards)
        # Merge the per-batch group dicts into one list per group key.
        all_group_strict_rewards_dict = defaultdict(list)
        all_group_rewards_dict = defaultdict(list)
        for current_dict in all_group_strict_rewards:
            for key, value in current_dict.items():
                all_group_strict_rewards_dict[key].extend(value)
        all_group_strict_rewards_dict = dict(all_group_strict_rewards_dict)
        for current_dict in all_group_rewards:
            for key, value in current_dict.items():
                all_group_rewards_dict[key].extend(value)
        all_group_rewards_dict = dict(all_group_rewards_dict)
        return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict

    return _fn


def ocr_score(device):
    from flow_grpo.ocr import OcrScorer

    scorer = OcrScorer()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        scores = scorer(images, prompts)
        return scores, {}

    return _fn
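# Unlike the closures above, the geneval closure takes an extra `only_strict`
# flag and returns five values: raw scores, rewards, strict rewards, and two
# dicts of per-task-group rewards. A hedged usage sketch (assumes `metadatas`
# entries follow the GenEval schema expected by flow_grpo.gen_eval):
#
#     fn = geneval_score(device)
#     scores, rewards, strict, group, group_strict = fn(
#         images, prompts, metadatas, only_strict=True
#     )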
def unifiedreward_score_sglang(device):
    import asyncio
    from openai import AsyncOpenAI
    import base64
    from io import BytesIO
    import re

    def pil_image_to_base64(image):
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_qwen = f"data:image;base64,{encoded_image_text}"
        return base64_qwen

    def _extract_scores(text_outputs):
        # Parse "Final Score: X" from each model response; default to 0.0.
        scores = []
        pattern = r"Final Score:\s*([1-5](?:\.\d+)?)"
        for text in text_outputs:
            match = re.search(pattern, text)
            if match:
                try:
                    scores.append(float(match.group(1)))
                except ValueError:
                    scores.append(0.0)
            else:
                scores.append(0.0)
        return scores

    client = AsyncOpenAI(base_url="http://127.0.0.1:17140/v1", api_key="flowgrpo")

    async def evaluate_image(prompt, image):
        question = f"\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after 'Final Score:'.\nYour task is provided as follows:\nText Caption: [{prompt}]"
        images_base64 = pil_image_to_base64(image)
        response = await client.chat.completions.create(
            model="UnifiedReward-7b-v1.5",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": images_base64},
                        },
                        {
                            "type": "text",
                            "text": question,
                        },
                    ],
                },
            ],
            temperature=0,
        )
        return response.choices[0].message.content

    async def evaluate_batch_image(images, prompts):
        tasks = [evaluate_image(prompt, img) for prompt, img in zip(prompts, images)]
        results = await asyncio.gather(*tasks)
        return results

    def _fn(images, prompts, metadata):
        # Handle Tensor type conversion
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        # Convert to PIL Image and resize
        images = [Image.fromarray(image).resize((512, 512)) for image in images]
        # Execute asynchronous batch evaluation
        text_outputs = asyncio.run(evaluate_batch_image(images, prompts))
        score = _extract_scores(text_outputs)
        # Normalize from the 1-5 rating scale to [0, 1].
        score = [sc / 5.0 for sc in score]
        return score, {}

    return _fn
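# The UnifiedReward scorer assumes an OpenAI-compatible sglang server is
# already serving UnifiedReward-7b-v1.5 at http://127.0.0.1:17140/v1 with API
# key "flowgrpo". One way to launch such a server, sketched under the
# assumption that sglang is installed (the model path is illustrative):
#
#     python -m sglang.launch_server --model-path UnifiedReward-7b-v1.5 \
#         --port 17140 --api-key flowgrpo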
def multi_score(device, score_dict):
    score_functions = {
        "ocr": ocr_score,
        "imagereward": imagereward_score,
        "pickscore": pickscore_score,
        "aesthetic": aesthetic_score,
        "jpeg_compressibility": jpeg_compressibility,
        "unifiedreward": unifiedreward_score_sglang,
        "geneval": geneval_score,
        "clipscore": clip_score,
        "hpsv2": hpsv2_score,
    }
    score_fns = {}
    for score_name in score_dict:
        # Only pass `device` to reward factories that accept it.
        score_fns[score_name] = (
            score_functions[score_name](device)
            if "device" in score_functions[score_name].__code__.co_varnames
            else score_functions[score_name]()
        )

    # `only_strict` applies only to geneval. During training only the strict
    # reward is needed, so skipping the non-strict rewards reduces reward
    # computation time.
    def _fn(images, prompts, metadata, only_strict=True):
        total_scores = []
        score_details = {}
        for score_name, weight in score_dict.items():
            if score_name == "geneval":
                scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](
                    images, prompts, metadata, only_strict
                )
                score_details["accuracy"] = rewards
                score_details["strict_accuracy"] = strict_rewards
                for key, value in group_strict_rewards.items():
                    score_details[f"{key}_strict_accuracy"] = value
                for key, value in group_rewards.items():
                    score_details[f"{key}_accuracy"] = value
            else:
                scores, _meta = score_fns[score_name](images, prompts, metadata)
                score_details[score_name] = scores
            # Accumulate the weighted per-image sum across rewards.
            weighted_scores = [weight * score for score in scores]
            if not total_scores:
                total_scores = weighted_scores
            else:
                total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]
        score_details["avg"] = total_scores
        return score_details, {}

    return _fn


def main():
    import torchvision.transforms as transforms

    image_paths = [
        "test_cases/nasa.jpg",
    ]
    transform = transforms.Compose(
        [
            transforms.ToTensor(),  # Convert to tensor
        ]
    )
    images = torch.stack([transform(Image.open(image_path).convert("RGB")) for image_path in image_paths])
    prompts = [
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    metadata = {}  # Example metadata
    score_dict = {"unifiedreward": 1.0}

    # Initialize the multi_score function with a device and score_dict
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scoring_fn = multi_score(device, score_dict)

    # Get the scores
    scores, _ = scoring_fn(images, prompts, metadata)

    # Print the scores
    print("Scores:", scores)


if __name__ == "__main__":
    main()
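# A multi-reward sketch (the weights are hypothetical): `multi_score` returns a
# dict with one entry per reward name plus "avg", the per-image weighted sum.
#
#     scoring_fn = multi_score(device, {"pickscore": 0.8, "jpeg_compressibility": 0.2})
#     details, _ = scoring_fn(images, prompts, metadata)
#     print(details["avg"])  # 0.8 * pickscore + 0.2 * jpeg_compressibility per image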