Pref-Restoration / process_image_degradation.py
process_image_degradation.py
Raw
#!/usr/bin/env python3
"""
Image degradation processing script
Reads a JSON file containing image paths and descriptions, applies degradation processing to the images, and generates a JSONL output file
"""

import json
import jsonlines
import os
import sys
from pathlib import Path
from PIL import Image
import argparse
from tqdm import tqdm

# Add blip3o module to path
sys.path.append(os.path.join(os.path.dirname(__file__), 'blip3o'))

from blip3o.data.image_degradation import degrade_image, ImageDegradationConfig


def load_json_data(json_file_path):
    """
    Load the input JSON file and return its top-level list of items.

    Args:
        json_file_path (str): Path to a UTF-8 JSON file whose top level is an array

    Returns:
        list: The decoded items, or an empty list (after printing an error) if the
        file cannot be read, is not valid JSON, or its top level is not an array
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to load JSON file: {e}")
        return []

    # Callers slice the result and call .get() on each item, so anything other
    # than a list would only fail later with a far less helpful error.
    if not isinstance(data, list):
        print(f"Failed to load JSON file: expected a top-level array, got {type(data).__name__}")
        return []

    print(f"Successfully loaded {len(data)} items")
    return data


def create_output_directory(output_dir):
    """
    Make sure the output directory exists, creating missing parent folders too.

    An already-existing directory is accepted without error.

    Args:
        output_dir (str): Directory path to create
    """
    target = Path(output_dir)
    target.mkdir(exist_ok=True, parents=True)
    print(f"Output directory created: {output_dir}")


def process_single_image(image_path, caption, output_dir, degradation_config, index):
    """
    Degrade one image, save it as PNG under ``output_dir``, and build its JSONL record.

    Args:
        image_path (str): Path to the original (ground-truth) image
        caption (str): Image description; surrounding whitespace is stripped to form the prompt
        output_dir (str): Directory the degraded PNG is written into
        degradation_config (ImageDegradationConfig): Object whose ``degrade`` method
            returns an object with a ``save`` method (presumably a PIL image — confirm)
        index (int): Position of the item in the input list; currently unused and
            kept only so existing callers keep working

    Returns:
        dict or None: ``{"prompt", "image", "requirement"}`` on success; None if the
        source is missing, would be overwritten, or any processing step raises
    """
    try:
        if not os.path.exists(image_path):
            print(f"Warning: Image file does not exist - {image_path}")
            return None

        # Keep the source stem but always write .png. Sources in different
        # folders that share a stem will therefore map to the same output file.
        name_without_ext = os.path.splitext(os.path.basename(image_path))[0]
        output_path = os.path.join(output_dir, f"{name_without_ext}.png")

        # If output_dir is the source folder and the source is already .png,
        # saving would replace the ground-truth image with its degraded copy.
        if os.path.realpath(output_path) == os.path.realpath(image_path):
            print(f"Warning: Refusing to overwrite source image - {image_path}")
            return None

        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded image that remains valid after closing.
        with Image.open(image_path) as img:
            img_gt = img.convert('RGB')

        img_degraded = degradation_config.degrade(img_gt)
        img_degraded.save(output_path)

        # The caption is used verbatim (minus surrounding whitespace) as the prompt.
        return {
            "prompt": caption.strip(),
            "image": output_path,
            "requirement": "Restore"
        }

    except Exception as e:
        # Per-image failures are reported and skipped so one bad file does not
        # abort the whole batch; the caller counts None as a failure.
        print(f"Failed to process image {image_path}: {e}")
        return None


def _parse_args():
    """
    Parse the command-line options for the degradation script.

    Returns:
        argparse.Namespace: with input_json, output_dir, output_jsonl and max_images
    """
    parser = argparse.ArgumentParser(description='Image degradation processing script')
    parser.add_argument('--input_json', type=str,
                        default='/data/zgq/yaozhengjian/Datasets/FFHQ/caption/long_captions.json',
                        help='Input JSON file path')
    parser.add_argument('--output_dir', type=str,
                        default='/data/zgq/yaozhengjian/Datasets/FFHQ/images_degraded',
                        help='Output directory for degraded images')
    parser.add_argument('--output_jsonl', type=str,
                        default='degraded_images.jsonl',
                        help='Output JSONL file path')
    parser.add_argument('--max_images', type=int, default=None,
                        help='Maximum number of images to process (for testing)')
    args = parser.parse_args()
    # A negative limit would silently drop items from the END of the list
    # (Python slice semantics) instead of limiting anything, so reject it.
    if args.max_images is not None and args.max_images < 0:
        parser.error('--max_images must be a non-negative integer')
    return args


def _build_degradation_config():
    """
    Return the fixed degradation settings applied to every image.

    The values are passed straight to ImageDegradationConfig; their exact
    meaning is defined by blip3o.data.image_degradation.
    """
    return ImageDegradationConfig(
        gt_size=512,
        in_size=512,
        use_motion_kernel=False,
        kernel_list=['iso', 'aniso'],
        kernel_prob=[0.5, 0.5],
        blur_kernel_size=41,
        blur_sigma=[1, 15],
        downsample_range=[4, 30],
        noise_range=[0, 20],
        jpeg_range=[30, 80]
    )


def _save_jsonl(results, output_jsonl):
    """
    Write each record in ``results`` as one JSON line to ``output_jsonl``.

    The parent directory is created if it does not exist. Errors from the
    filesystem or jsonlines propagate to the caller.
    """
    parent_dir = os.path.dirname(output_jsonl)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with jsonlines.open(output_jsonl, 'w') as writer:
        for result in results:
            writer.write(result)


def main():
    """
    Degrade every image listed in the input JSON and write a JSONL index of the results.

    Returns:
        int: 0 on success; 1 if the input JSON yielded no items or the JSONL
        file could not be written. Individual image failures are counted and
        reported but do not change the return value.
    """
    args = _parse_args()

    print("Starting image degradation processing...")
    print(f"Input JSON file: {args.input_json}")
    print(f"Output directory: {args.output_dir}")
    print(f"Output JSONL file: {args.output_jsonl}")

    data = load_json_data(args.input_json)
    if not data:
        print("No data to process, exiting program")
        return 1

    # Explicit None check so --max_images 0 means "process nothing" rather
    # than being treated as falsy, i.e. "no limit".
    if args.max_images is not None:
        data = data[:args.max_images]
        print(f"Limiting processing count to: {args.max_images}")

    create_output_directory(args.output_dir)
    degradation_config = _build_degradation_config()

    results = []
    failed_count = 0

    print("Starting image processing...")
    for i, item in enumerate(tqdm(data, desc="Processing images")):
        # tqdm.write prints above the progress bar instead of corrupting it.
        if not isinstance(item, dict):
            tqdm.write(f"Skipping invalid data item {i}: expected an object, got {type(item).__name__}")
            failed_count += 1
            continue

        image_path = item.get('image', '')
        caption = item.get('caption', '')
        if not image_path or not caption:
            tqdm.write(f"Skipping invalid data item {i}: missing image or caption field")
            failed_count += 1
            continue

        result = process_single_image(
            image_path=image_path,
            caption=caption,
            output_dir=args.output_dir,
            degradation_config=degradation_config,
            index=i
        )

        if result:
            results.append(result)
        else:
            failed_count += 1

    print(f"Saving JSONL file to: {args.output_jsonl}")
    try:
        _save_jsonl(results, args.output_jsonl)
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the return code.
        print(f"Failed to save JSONL file: {e}")
        return 1

    print("Processing complete!")
    print(f"Successfully processed: {len(results)} images")
    print(f"Failed: {failed_count} images")
    print(f"JSONL file saved: {args.output_jsonl}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so shell
    # pipelines can detect failure (sys.exit(None) exits with status 0).
    sys.exit(main())