#!/usr/bin/env python3
"""
Image score analysis script.

Analyzes images in a specified directory and calculates CLIP score, HPSv2 score,
and PickScore for each image.
"""

import os
import json
import glob
import argparse
import logging

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

# Import scorers
from flow_grpo.clip_scorer import ClipScorer
from flow_grpo.hpsv2_scorer import HPSv2Scorer
from flow_grpo.pickscore_scorer import PickScoreScorer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ImageScoreAnalyzer:
    def __init__(self, device="cuda", dtype=torch.float32):
        """Initialize scorers."""
        self.device = device
        self.dtype = dtype

        logger.info("Initializing scorers...")
        self.clip_scorer = ClipScorer(device=device)
        self.hpsv2_scorer = HPSv2Scorer(dtype=dtype, device=device)
        self.pickscore_scorer = PickScoreScorer(device=device, dtype=dtype)
        logger.info("Scorers initialization complete")

    def load_captions(self, caption_file_path):
        """Load captions from a JSON file."""
        logger.info(f"Loading captions file: {caption_file_path}")
        with open(caption_file_path, 'r', encoding='utf-8') as f:
            captions_data = json.load(f)

        # Create mapping from validation_n to caption
        caption_map = {}
        for item in captions_data:
            image_path = item['image']
            caption = item['caption']
            # Extract validation_id from the image filename (strip the extension)
            filename = os.path.basename(image_path)
            validation_id = filename.replace('.png', '').replace('.jpg', '')
            caption_map[validation_id] = caption

        logger.info(f"Loaded {len(caption_map)} captions")
        return caption_map

    def load_images_from_directory(self, directory_path):
        """Load all PNG images from a directory."""
        png_files = glob.glob(os.path.join(directory_path, "*.png"))
        png_files.sort()  # Ensure consistent order

        images = []
        for img_path in png_files:
            try:
                img = Image.open(img_path).convert('RGB')
                images.append(np.array(img))
            except Exception as e:
                logger.warning(f"Could not load image {img_path}: {e}")
                continue

        return images, png_files

    def preprocess_images(self, images):
        """Preprocess images into a normalized tensor (assumes all images share one size)."""
        if not images:
            return None

        images_array = np.array(images)
        images_array = images_array.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        # uint8 [0, 255] -> float32 [0, 1]
        images_tensor = torch.tensor(images_array, dtype=torch.float32) / 255.0
        return images_tensor

    def calculate_scores(self, images_tensor, caption):
        """Calculate the three types of scores for a batch of images."""
        batch_size = images_tensor.shape[0]
        captions = [caption] * batch_size

        # Calculate CLIP score
        clip_scores = self.clip_scorer(images_tensor, captions).cpu().numpy()

        # Calculate HPSv2 score
        hpsv2_scores = self.hpsv2_scorer(images_tensor, captions).cpu().numpy()

        # Calculate PickScore (note the different parameter order: prompts first, then images)
        pil_images = [
            Image.fromarray((images_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
            for i in range(batch_size)
        ]
        pickscore_scores = self.pickscore_scorer(captions, pil_images).cpu().numpy()

        return {
            'clip_scores': clip_scores.tolist(),
            'hpsv2_scores': hpsv2_scores.tolist(),
            'pickscore_scores': pickscore_scores.tolist()
        }

    def calculate_statistics(self, scores):
        """Calculate summary statistics for each score type."""
        stats = {}
        for score_type, score_list in scores.items():
            if score_list:
                stats[score_type] = {
                    'mean': float(np.mean(score_list)),
                    'std': float(np.std(score_list)),
                    'min': float(np.min(score_list)),
                    'max': float(np.max(score_list)),
                    'median': float(np.median(score_list))
                }
            else:
                stats[score_type] = None
        return stats

    def analyze_directory(self, data_dir, caption_map, output_file):
        """Analyze the entire data directory."""
        results = {}
        overall_stats = {
            'clip_scores': [],
            'hpsv2_scores': [],
            'pickscore_scores': []
        }

        # Get all validation_* subdirectories
        subdirs = [
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d)) and d.startswith('validation_')
        ]
        subdirs.sort()

        logger.info(f"Found {len(subdirs)} data directories")

        for subdir in tqdm(subdirs, desc="Processing data directories"):
            subdir_path = os.path.join(data_dir, subdir)

            # Get the corresponding caption
            if subdir not in caption_map:
                logger.warning(f"No caption found for {subdir}, skipping")
                continue
            caption = caption_map[subdir]

            # Load images
            images, image_paths = self.load_images_from_directory(subdir_path)
            if len(images) == 0:
                logger.warning(f"No images found in directory {subdir}")
                continue

            logger.info(f"Processing {subdir}: Found {len(images)} images")

            # Preprocess images
            images_tensor = self.preprocess_images(images)

            # Calculate scores
            scores = self.calculate_scores(images_tensor, caption)

            # Calculate statistics
            stats = self.calculate_statistics(scores)

            # Save per-directory results
            results[subdir] = {
                'caption': caption,
                'num_images': len(images),
                'image_paths': [os.path.basename(p) for p in image_paths],
                'scores': scores,
                'statistics': stats
            }

            # Accumulate overall statistics
            overall_stats['clip_scores'].extend(scores['clip_scores'])
            overall_stats['hpsv2_scores'].extend(scores['hpsv2_scores'])
            overall_stats['pickscore_scores'].extend(scores['pickscore_scores'])

        # Calculate overall statistics
        overall_statistics = self.calculate_statistics(overall_stats)

        # Save results
        final_results = {
            'overall_statistics': overall_statistics,
            'total_data_entries': len(results),
            'total_images': sum(len(data['scores']['clip_scores']) for data in results.values()),
            'detailed_results': results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2, ensure_ascii=False)

        logger.info(f"Results saved to {output_file}")
        return final_results


def main():
    parser = argparse.ArgumentParser(description='Image score analysis tool')
    parser.add_argument('--data_dir', type=str,
                        default='/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Eval/RL_Val/All',
                        help='Data directory path')
    parser.add_argument('--caption_file', type=str,
                        default='/data/zgq/yaozhengjian/Datasets/FFHQ_val/CelebA_HQ/captions_lq.json',
                        help='Captions file path')
    parser.add_argument('--output_file', type=str, default='image_scores_analysis.json',
                        help='Output file path')
    parser.add_argument('--device', type=str, default='cuda', help='Computing device')

    args = parser.parse_args()

    # Check if paths exist
    if not os.path.exists(args.data_dir):
        logger.error(f"Data directory does not exist: {args.data_dir}")
        return
    if not os.path.exists(args.caption_file):
        logger.error(f"Caption file does not exist: {args.caption_file}")
        return

    # Initialize analyzer
    analyzer = ImageScoreAnalyzer(device=args.device)

    # Load captions
    caption_map = analyzer.load_captions(args.caption_file)

    # Run analysis
    results = analyzer.analyze_directory(args.data_dir, caption_map, args.output_file)

    # Print overall statistics
    logger.info("=== Overall Statistics ===")
    for score_type, stats in results['overall_statistics'].items():
        if stats:
            logger.info(f"{score_type}: Average={stats['mean']:.4f}, StdDev={stats['std']:.4f}")


if __name__ == "__main__":
    main()
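
# Example invocation (illustrative sketch only: the script filename and the paths
# below are placeholders, not values taken from this repository):
#
#   python analyze_image_scores.py \
#       --data_dir /path/to/validation_outputs \
#       --caption_file /path/to/captions.json \
#       --output_file image_scores_analysis.json \
#       --device cuda
#
# The data directory is expected to contain validation_* subdirectories of PNG images,
# and the captions file a JSON list of {"image": ..., "caption": ...} entries, matching
# what analyze_directory() and load_captions() consume above.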