#!/usr/bin/env python3
"""
Image score analysis script
Analyzes images in a specified directory, calculates CLIP score, HPSv2 score and PickScore
"""
import os
import json
import glob
import numpy as np
import torch
from PIL import Image
from pathlib import Path
import argparse
from tqdm import tqdm
import logging
# Import scorers
from flow_grpo.clip_scorer import ClipScorer
from flow_grpo.hpsv2_scorer import HPSv2Scorer
from flow_grpo.pickscore_scorer import PickScoreScorer
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
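# Expected inputs (inferred from the code below, not from a formal spec):
#   - data_dir contains validation_* subdirectories, each holding the PNG samples
#     generated for one prompt.
#   - The captions file is a JSON list of {"image": ..., "caption": ...} entries;
#     the image basename without its extension (e.g. "validation_0001") must match
#     the corresponding subdirectory name, since it is used as the lookup key.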
class ImageScoreAnalyzer:
    def __init__(self, device="cuda", dtype=torch.float32):
        """Initialize scorers"""
        self.device = device
        self.dtype = dtype
        logger.info("Initializing scorers...")
        self.clip_scorer = ClipScorer(device=device)
        self.hpsv2_scorer = HPSv2Scorer(dtype=dtype, device=device)
        self.pickscore_scorer = PickScoreScorer(device=device, dtype=dtype)
        logger.info("Scorer initialization complete")

    def load_captions(self, caption_file_path):
        """Load captions from a JSON file and build a validation_id -> caption map"""
        logger.info(f"Loading captions file: {caption_file_path}")
        with open(caption_file_path, 'r', encoding='utf-8') as f:
            captions_data = json.load(f)
        # Create mapping from validation_n to caption
        caption_map = {}
        for item in captions_data:
            image_path = item['image']
            caption = item['caption']
            # Extract validation_id (image filename without its extension)
            filename = os.path.basename(image_path)
            validation_id = os.path.splitext(filename)[0]
            caption_map[validation_id] = caption
        logger.info(f"Loaded {len(caption_map)} captions")
        return caption_map

    def load_images_from_directory(self, directory_path):
        """Load all PNG images from a directory as RGB arrays"""
        png_files = glob.glob(os.path.join(directory_path, "*.png"))
        png_files.sort()  # Ensure consistent order
        images = []
        loaded_paths = []
        for img_path in png_files:
            try:
                img = Image.open(img_path).convert('RGB')
                images.append(np.array(img))
                # Track paths only for images that actually loaded, so the two
                # returned lists stay aligned even when a file is skipped below
                loaded_paths.append(img_path)
            except Exception as e:
                logger.warning(f"Could not load image {img_path}: {e}")
                continue
        return images, loaded_paths

    def preprocess_images(self, images):
        """Stack images into a float tensor in NCHW format, scaled to [0, 1]"""
        if not images:
            return None
        # Note: stacking assumes all images in a directory share the same resolution
        images_array = np.array(images)
        images_array = images_array.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        images_tensor = torch.tensor(images_array, dtype=torch.float32) / 255.0
        return images_tensor

    def calculate_scores(self, images_tensor, caption):
        """Calculate CLIP, HPSv2 and PickScore scores for one batch of images"""
        batch_size = images_tensor.shape[0]
        captions = [caption] * batch_size
        # Calculate CLIP score
        clip_scores = self.clip_scorer(images_tensor, captions).cpu().numpy()
        # Calculate HPSv2 score
        hpsv2_scores = self.hpsv2_scorer(images_tensor, captions).cpu().numpy()
        # Calculate PickScore (note: this scorer takes captions first, then PIL images)
        pil_images = [
            Image.fromarray((images_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
            for i in range(batch_size)
        ]
        pickscore_scores = self.pickscore_scorer(captions, pil_images).cpu().numpy()
        return {
            'clip_scores': clip_scores.tolist(),
            'hpsv2_scores': hpsv2_scores.tolist(),
            'pickscore_scores': pickscore_scores.tolist()
        }

    def calculate_statistics(self, scores):
        """Calculate statistics"""
        stats = {}
        for score_type, score_list in scores.items():
            if score_list:
                stats[score_type] = {
                    'mean': float(np.mean(score_list)),
                    'std': float(np.std(score_list)),
                    'min': float(np.min(score_list)),
                    'max': float(np.max(score_list)),
                    'median': float(np.median(score_list))
                }
            else:
                stats[score_type] = None
        return stats

    def analyze_directory(self, data_dir, caption_map, output_file):
        """Analyze the entire data directory"""
        results = {}
        overall_stats = {
            'clip_scores': [],
            'hpsv2_scores': [],
            'pickscore_scores': []
        }
        # Get all subdirectories
        subdirs = [d for d in os.listdir(data_dir)
                   if os.path.isdir(os.path.join(data_dir, d)) and d.startswith('validation_')]
        subdirs.sort()
        logger.info(f"Found {len(subdirs)} data directories")
        for subdir in tqdm(subdirs, desc="Processing data directories"):
            subdir_path = os.path.join(data_dir, subdir)
            # Get corresponding caption
            if subdir not in caption_map:
                logger.warning(f"No caption found for {subdir}, skipping")
                continue
            caption = caption_map[subdir]
            # Load images
            images, image_paths = self.load_images_from_directory(subdir_path)
            if len(images) == 0:
                logger.warning(f"No images found in directory {subdir}")
                continue
            logger.info(f"Processing {subdir}: Found {len(images)} images")
            # Preprocess images
            images_tensor = self.preprocess_images(images)
            # Calculate scores
            scores = self.calculate_scores(images_tensor, caption)
            # Calculate statistics
            stats = self.calculate_statistics(scores)
            # Save results
            results[subdir] = {
                'caption': caption,
                'num_images': len(images),
                'image_paths': [os.path.basename(p) for p in image_paths],
                'scores': scores,
                'statistics': stats
            }
            # Accumulate overall statistics
            overall_stats['clip_scores'].extend(scores['clip_scores'])
            overall_stats['hpsv2_scores'].extend(scores['hpsv2_scores'])
            overall_stats['pickscore_scores'].extend(scores['pickscore_scores'])
        # Calculate overall statistics
        overall_statistics = self.calculate_statistics(overall_stats)
        # Save results
        final_results = {
            'overall_statistics': overall_statistics,
            'total_data_entries': len(results),
            'total_images': sum(len(data['scores']['clip_scores']) for data in results.values()),
            'detailed_results': results
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2, ensure_ascii=False)
        logger.info(f"Results saved to {output_file}")
        return final_results

def main():
    parser = argparse.ArgumentParser(description='Image score analysis tool')
    parser.add_argument('--data_dir', type=str,
                        default='/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Eval/RL_Val/All',
                        help='Data directory path')
    parser.add_argument('--caption_file', type=str,
                        default='/data/zgq/yaozhengjian/Datasets/FFHQ_val/CelebA_HQ/captions_lq.json',
                        help='Captions file path')
    parser.add_argument('--output_file', type=str, default='image_scores_analysis.json',
                        help='Output file path')
    parser.add_argument('--device', type=str, default='cuda', help='Computing device')
    args = parser.parse_args()
    # Check if paths exist
    if not os.path.exists(args.data_dir):
        logger.error(f"Data directory does not exist: {args.data_dir}")
        return
    if not os.path.exists(args.caption_file):
        logger.error(f"Caption file does not exist: {args.caption_file}")
        return
    # Initialize analyzer
    analyzer = ImageScoreAnalyzer(device=args.device)
    # Load captions
    caption_map = analyzer.load_captions(args.caption_file)
    # Run analysis
    results = analyzer.analyze_directory(args.data_dir, caption_map, args.output_file)
    # Print overall statistics
    logger.info("=== Overall Statistics ===")
    for score_type, stats in results['overall_statistics'].items():
        if stats:
            logger.info(f"{score_type}: Average={stats['mean']:.4f}, StdDev={stats['std']:.4f}")

if __name__ == "__main__":
    main()