#!/usr/bin/env python3
"""
Image score analysis script
Analyzes images in a specified directory and computes CLIP score, HPSv2 score, and PickScore for each image
"""

import os
import json
import glob
import numpy as np
import torch
from PIL import Image
import argparse
from tqdm import tqdm
import logging

# Import scorers
from flow_grpo.clip_scorer import ClipScorer
from flow_grpo.hpsv2_scorer import HPSv2Scorer
from flow_grpo.pickscore_scorer import PickScoreScorer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ImageScoreAnalyzer:
    def __init__(self, device="cuda", dtype=torch.float32):
        """Initialize scorers"""
        self.device = device
        self.dtype = dtype
        
        logger.info("Initializing scorers...")
        self.clip_scorer = ClipScorer(device=device)
        self.hpsv2_scorer = HPSv2Scorer(dtype=dtype, device=device)
        self.pickscore_scorer = PickScoreScorer(device=device, dtype=dtype)
        logger.info("Scorers initialization complete")
        
    def load_captions(self, caption_file_path):
        """Load captions from JSON file"""
        logger.info(f"Loading captions file: {caption_file_path}")
        with open(caption_file_path, 'r', encoding='utf-8') as f:
            captions_data = json.load(f)
        
        # Create mapping from validation_n to caption
        caption_map = {}
        for item in captions_data:
            image_path = item['image']
            caption = item['caption']
            # Use the file name without its extension as the validation_id
            filename = os.path.basename(image_path)
            validation_id = os.path.splitext(filename)[0]
            caption_map[validation_id] = caption
            
        logger.info(f"Loaded {len(caption_map)} captions")
        return caption_map
    
    def load_images_from_directory(self, directory_path):
        """Load all PNG images from directory"""
        png_files = glob.glob(os.path.join(directory_path, "*.png"))
        png_files.sort()  # Ensure consistent order
        
        images = []
        for img_path in png_files:
            try:
                img = Image.open(img_path).convert('RGB')
                images.append(np.array(img))
            except Exception as e:
                logger.warning(f"Could not load image {img_path}: {e}")
                continue
                
        return images, png_files
    
    def preprocess_images(self, images):
        """Stack images into a float tensor in NCHW format with values in [0, 1]"""
        if not images:
            return None

        # All images in a directory are assumed to share the same resolution,
        # e.g. N images of 512x512 -> a tensor of shape (N, 3, 512, 512).
        images_array = np.array(images)
        images_array = images_array.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        images_tensor = torch.tensor(images_array, dtype=torch.float32) / 255.0
        return images_tensor
    
    def calculate_scores(self, images_tensor, caption):
        """Calculate three types of scores"""
        batch_size = images_tensor.shape[0]
        captions = [caption] * batch_size
        
        # Calculate CLIP score
        clip_scores = self.clip_scorer(images_tensor, captions).cpu().numpy()
        
        # Calculate HPSv2 score
        hpsv2_scores = self.hpsv2_scorer(images_tensor, captions).cpu().numpy()
        
        # Calculate PickScore (unlike the other scorers, this one takes captions first and expects PIL images)
        pil_images = [Image.fromarray((images_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) 
                     for i in range(batch_size)]
        pickscore_scores = self.pickscore_scorer(captions, pil_images).cpu().numpy()
        
        return {
            'clip_scores': clip_scores.tolist(),
            'hpsv2_scores': hpsv2_scores.tolist(),
            'pickscore_scores': pickscore_scores.tolist()
        }
    
    def calculate_statistics(self, scores):
        """Calculate statistics"""
        stats = {}
        for score_type, score_list in scores.items():
            if score_list:
                stats[score_type] = {
                    'mean': float(np.mean(score_list)),
                    'std': float(np.std(score_list)),
                    'min': float(np.min(score_list)),
                    'max': float(np.max(score_list)),
                    'median': float(np.median(score_list))
                }
            else:
                stats[score_type] = None
        return stats
    
    def analyze_directory(self, data_dir, caption_map, output_file):
        """Analyze the entire data directory"""
        results = {}
        overall_stats = {
            'clip_scores': [],
            'hpsv2_scores': [],
            'pickscore_scores': []
        }
        
        # Get all subdirectories
        subdirs = [d for d in os.listdir(data_dir) 
                  if os.path.isdir(os.path.join(data_dir, d)) and d.startswith('validation_')]
        subdirs.sort()
        
        logger.info(f"Found {len(subdirs)} data directories")
        
        for subdir in tqdm(subdirs, desc="Processing data directories"):
            subdir_path = os.path.join(data_dir, subdir)
            
            # Get corresponding caption
            if subdir not in caption_map:
                logger.warning(f"No caption found for {subdir}, skipping")
                continue
                
            caption = caption_map[subdir]
            
            # Load images
            images, image_paths = self.load_images_from_directory(subdir_path)
            
            if len(images) == 0:
                logger.warning(f"No images found in directory {subdir}")
                continue
                
            logger.info(f"Processing {subdir}: Found {len(images)} images")
            
            # Preprocess images
            images_tensor = self.preprocess_images(images)
            
            # Calculate scores
            scores = self.calculate_scores(images_tensor, caption)
            
            # Calculate statistics
            stats = self.calculate_statistics(scores)
            
            # Save results
            results[subdir] = {
                'caption': caption,
                'num_images': len(images),
                'image_paths': [os.path.basename(p) for p in image_paths],
                'scores': scores,
                'statistics': stats
            }
            
            # Accumulate overall statistics
            overall_stats['clip_scores'].extend(scores['clip_scores'])
            overall_stats['hpsv2_scores'].extend(scores['hpsv2_scores'])
            overall_stats['pickscore_scores'].extend(scores['pickscore_scores'])
        
        # Calculate overall statistics
        overall_statistics = self.calculate_statistics(overall_stats)
        
        # Save results
        final_results = {
            'overall_statistics': overall_statistics,
            'total_data_entries': len(results),
            'total_images': sum(len(data['scores']['clip_scores']) for data in results.values()),
            'detailed_results': results
        }
        
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2, ensure_ascii=False)
            
        logger.info(f"Results saved to {output_file}")
        return final_results
 
 
def main():
    parser = argparse.ArgumentParser(description='Image score analysis tool')
    parser.add_argument('--data_dir', type=str, 
                       default='/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Eval/RL_Val/All',
                       help='Data directory path')
    parser.add_argument('--caption_file', type=str,
                       default='/data/zgq/yaozhengjian/Datasets/FFHQ_val/CelebA_HQ/captions_lq.json',
                       help='Captions file path')
    parser.add_argument('--output_file', type=str, default='image_scores_analysis.json',
                       help='Output file path')
    parser.add_argument('--device', type=str, default='cuda', help='Computing device')
    
    args = parser.parse_args()
    
    # Check if paths exist
    if not os.path.exists(args.data_dir):
        logger.error(f"Data directory does not exist: {args.data_dir}")
        return
        
    if not os.path.exists(args.caption_file):
        logger.error(f"Caption file does not exist: {args.caption_file}")
        return
    
    # Initialize analyzer
    analyzer = ImageScoreAnalyzer(device=args.device)
    
    # Load captions
    caption_map = analyzer.load_captions(args.caption_file)
    
    # Run analysis
    results = analyzer.analyze_directory(args.data_dir, caption_map, args.output_file)
    
    # Print overall statistics
    logger.info("=== Overall Statistics ===")
    for score_type, stats in results['overall_statistics'].items():
        if stats:
            logger.info(f"{score_type}: Average={stats['mean']:.4f}, StdDev={stats['std']:.4f}")


if __name__ == "__main__":
    main()