# Pref-Restoration / DiffusionNFT / model / inspect_model_structure.py
"""
Test script: Load PrefRestorePipeline and print model structure
"""

import sys
import os
import torch

# Make the DiffusionNFT project root importable. The fallback below is a
# machine-specific absolute path; set the DIFFUSIONNFT_ROOT environment
# variable to point at a different checkout. An unset or empty variable
# falls back to the hard-coded path.
project_root = (
    os.environ.get("DIFFUSIONNFT_ROOT")
    or "/data/phd/yaozhengjian/Code/RL/ART-FRv2/DiffusionNFT"
)
if project_root not in sys.path:
    # Prepend so project modules take precedence over same-named installed ones.
    sys.path.insert(0, project_root)

# from DiffusionNFT.model.PrefRestorePipeline_pipeline import PrefRestorePipeline
from PrefRestorePipeline_pipeline import PrefRestorePipeline

# Defaults equal the original hard-coded values, so calling
# print_model_structure() with no arguments behaves as before.
DEFAULT_MODEL_PATH = "/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restoration_FFHQ_VAE_Step3_scaling/checkpoint-108000"
DEFAULT_OUTPUT_FILE = "PrefRestorePipeline_model_structure.txt"

# Substrings matched case-insensitively against module names to flag likely
# freeze candidates. Heuristic only: e.g. "norm" matches any module whose name
# contains "norm", and "encoder" matches any name containing "encoder".
_FREEZE_KEYWORDS = (
    "vision_tower", "image_processor", "vision_model",
    "embeddings", "encoder", "layernorm", "norm",
    "position_embedding", "patch_embedding",
)


def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _collect_module_stats(model):
    """Count the parameters each module of ``model`` registers directly.

    ``parameters(recurse=False)`` attributes every parameter to the module that
    owns it. This also covers parameters registered on non-leaf modules, which
    a leaf-only scan would silently miss. Modules owning no parameters are
    omitted.

    Returns:
        dict mapping module name -> {"total", "trainable", "frozen",
        "trainable_ratio"}.
    """
    stats = {}
    for name, module in model.named_modules():
        own_params = list(module.parameters(recurse=False))
        total = sum(p.numel() for p in own_params)
        if total == 0:
            continue
        trainable = sum(p.numel() for p in own_params if p.requires_grad)
        stats[name] = {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "trainable_ratio": _percent(trainable, total),
        }
    return stats


def _suggest_freeze_targets(model, keywords=_FREEZE_KEYWORDS):
    """Return names of modules whose (lower-cased) name contains a keyword.

    A module is skipped when one of its ancestors was already selected,
    because calling ``requires_grad_(False)`` on the ancestor covers it too.
    """
    selected = []
    selected_set = set()
    for name, _module in model.named_modules():
        lowered = name.lower()
        if not any(keyword in lowered for keyword in keywords):
            continue
        parts = name.split(".")
        # Check every proper ancestor path, e.g. "a" and "a.b" for "a.b.c".
        if any(".".join(parts[:i]) in selected_set for i in range(1, len(parts))):
            continue
        selected.append(name)
        selected_set.add(name)
    return selected


def print_model_structure(model_path=DEFAULT_MODEL_PATH, output_file=DEFAULT_OUTPUT_FILE):
    """Load PrefRestorePipeline and write a model-structure report to a file.

    The report contains basic pipeline info, the full module hierarchy, global
    parameter counts, a per-module parameter breakdown, name-based freezing
    suggestions and generic training hints. A short summary is also printed
    to stdout.

    Args:
        model_path: Checkpoint directory passed to
            ``PrefRestorePipeline.from_pretrained``.
        output_file: Path of the text report; it is overwritten.
    """
    print("Loading PrefRestorePipeline...")
    use_cuda = torch.cuda.is_available()
    pipeline = PrefRestorePipeline.from_pretrained(
        model_path=model_path,
        # bfloat16 on GPU, float32 on CPU.
        device="cuda:0" if use_cuda else "cpu",
        dtype=torch.bfloat16 if use_cuda else torch.float32,
    )
    print("Pipeline loaded successfully!")

    model = pipeline.model
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    # Guarded so a parameter-less model does not raise ZeroDivisionError.
    trainable_ratio = _percent(trainable_params, total_params)

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("=" * 100 + "\n")
        f.write("PrefRestorePipeline Model Structure Detailed Analysis\n")
        f.write("=" * 100 + "\n\n")

        # 1. Basic Information
        f.write("1. Basic Information\n")
        f.write("-" * 50 + "\n")
        f.write(f"Model Path: {model_path}\n")
        f.write(f"Device: {pipeline.device}\n")
        f.write(f"Data Type: {pipeline.config.dtype}\n")
        f.write(f"Main Model Type: {type(model).__name__}\n")
        f.write(f"Tokenizer Type: {type(pipeline.tokenizer).__name__}\n")
        f.write(f"Processor Type: {type(pipeline.processor).__name__}\n\n")

        # 2. Model Hierarchy (named_modules() yields the root under the name "")
        f.write("2. Model Hierarchy\n")
        f.write("-" * 50 + "\n")
        for name, module in model.named_modules():
            f.write(f"{name or '(root)'}: {type(module).__name__}\n")
        f.write("\n")

        # 3. Parameter Statistics
        f.write("3. Parameter Statistics\n")
        f.write("-" * 50 + "\n")
        f.write(f"Total Parameters: {total_params:,}\n")
        f.write(f"Trainable Parameters: {trainable_params:,}\n")
        f.write(f"Frozen Parameters: {frozen_params:,}\n")
        f.write(f"Trainable Ratio: {trainable_ratio:.2f}%\n\n")

        # 4. Detailed Parameter Breakdown by Module (largest first)
        f.write("4. Detailed Parameter Breakdown by Module\n")
        f.write("-" * 50 + "\n")
        module_stats = _collect_module_stats(model)
        sorted_modules = sorted(
            module_stats.items(), key=lambda item: item[1]["total"], reverse=True
        )
        for name, stats in sorted_modules:
            f.write(f"\n{name or '(root)'}:\n")
            f.write(f"  Total Parameters: {stats['total']:,}\n")
            f.write(f"  Trainable: {stats['trainable']:,}\n")
            f.write(f"  Frozen: {stats['frozen']:,}\n")
            f.write(f"  Trainable Ratio: {stats['trainable_ratio']:.2f}%\n")

        # 5. Suggested Freezing Strategy
        f.write("\n\n5. Suggested Freezing Strategy\n")
        f.write("-" * 50 + "\n")
        f.write("# Freeze suggestions based on module names:\n\n")
        freeze_suggestions = _suggest_freeze_targets(model)
        if freeze_suggestions:
            f.write("# Suggested modules to freeze:\n")
            for suggestion in freeze_suggestions:
                # get_submodule() accepts dotted paths with numeric indices
                # (e.g. "layers.0.norm"); plain attribute syntax would be a
                # SyntaxError for those.
                f.write(f"pipeline.model.get_submodule({suggestion!r}).requires_grad_(False)\n")
        else:
            f.write("# No obvious modules to freeze found, please set manually as needed\n")

        # 6. Training Suggestions
        f.write("\n\n6. Training Suggestions\n")
        f.write("-" * 50 + "\n")
        f.write("# Training suggestions based on model structure:\n\n")

        # The repr of a large model is long; build and lower-case it once.
        model_repr = str(model).lower()
        if "vision" in model_repr:
            f.write("- Vision module detected, suggested to freeze pre-trained vision encoder\n")
        if "language" in model_repr or "llm" in model_repr:
            f.write("- Language module detected, suggested to freeze pre-trained language model backbone\n")
        if "vae" in model_repr:
            f.write("- VAE module detected, suggested to freeze VAE encoder and decoder\n")

        f.write("\nSuggested to only train specific adaptation layers or newly added modules\n")
        f.write(f"High total parameter count({total_params:,}), suggested to use parameter-efficient training methods like LoRA\n")

    print(f"Model structure analysis saved to: {output_file}")

    # 7. Print key info to console
    print("\n" + "=" * 60)
    print("Model Structure Overview:")
    print("=" * 60)
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Frozen Parameters: {frozen_params:,}")
    print(f"Trainable Ratio: {trainable_ratio:.2f}%")

    print(f"\nDetailed analysis saved to: {output_file}")
        


if __name__ == "__main__":
    print_model_structure()