""" Test script: Load PrefRestorePipeline and print model structure """ import sys import os import torch # Add project path project_root = "/data/phd/yaozhengjian/Code/RL/ART-FRv2/DiffusionNFT" if project_root not in sys.path: sys.path.insert(0, project_root) # from DiffusionNFT.model.PrefRestorePipeline_pipeline import PrefRestorePipeline from PrefRestorePipeline_pipeline import PrefRestorePipeline def print_model_structure(): """Print model structure to file""" # Configure model path (please modify according to actual situation) model_path = "/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restoration_FFHQ_VAE_Step3_scaling/checkpoint-108000" print("Loading PrefRestorePipeline...") pipeline = PrefRestorePipeline.from_pretrained( model_path=model_path, device="cuda:0" if torch.cuda.is_available() else "cpu", dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32 ) print("Pipeline loaded successfully!") # Create detailed model structure report output_file = "PrefRestorePipeline_model_structure.txt" with open(output_file, "w", encoding="utf-8") as f: f.write("=" * 100 + "\n") f.write("PrefRestorePipeline Model Structure Detailed Analysis\n") f.write("=" * 100 + "\n\n") # 1. Basic Information f.write("1. Basic Information\n") f.write("-" * 50 + "\n") f.write(f"Model Path: {model_path}\n") f.write(f"Device: {pipeline.device}\n") f.write(f"Data Type: {pipeline.config.dtype}\n") f.write(f"Main Model Type: {type(pipeline.model).__name__}\n") f.write(f"Tokenizer Type: {type(pipeline.tokenizer).__name__}\n") f.write(f"Processor Type: {type(pipeline.processor).__name__}\n\n") # 2. Model Hierarchy f.write("2. Model Hierarchy\n") f.write("-" * 50 + "\n") for name, module in pipeline.model.named_modules(): f.write(f"{name}: {type(module).__name__}\n") f.write("\n") # 3. Parameter Statistics f.write("3. Parameter Statistics\n") f.write("-" * 50 + "\n") total_params = sum(p.numel() for p in pipeline.model.parameters()) trainable_params = sum(p.numel() for p in pipeline.model.parameters() if p.requires_grad) frozen_params = total_params - trainable_params f.write(f"Total Parameters: {total_params:,}\n") f.write(f"Trainable Parameters: {trainable_params:,}\n") f.write(f"Frozen Parameters: {frozen_params:,}\n") f.write(f"Trainable Ratio: {trainable_params/total_params*100:.2f}%\n\n") # 4. Detailed Parameter Breakdown by Module f.write("4. Detailed Parameter Breakdown by Module\n") f.write("-" * 50 + "\n") module_stats = {} for name, module in pipeline.model.named_modules(): if len(list(module.children())) == 0: # Leaf nodes total = sum(p.numel() for p in module.parameters()) trainable = sum(p.numel() for p in module.parameters() if p.requires_grad) if total > 0: module_stats[name] = { 'total': total, 'trainable': trainable, 'frozen': total - trainable, 'trainable_ratio': trainable / total * 100 if total > 0 else 0 } # Sort by parameter count sorted_modules = sorted(module_stats.items(), key=lambda x: x[1]['total'], reverse=True) for name, stats in sorted_modules: f.write(f"\n{name}:\n") f.write(f" Total Parameters: {stats['total']:,}\n") f.write(f" Trainable: {stats['trainable']:,}\n") f.write(f" Frozen: {stats['frozen']:,}\n") f.write(f" Trainable Ratio: {stats['trainable_ratio']:.2f}%\n") # 5. Suggested Freezing Strategy f.write("\n\n5. Suggested Freezing Strategy\n") f.write("-" * 50 + "\n") f.write("# Freeze suggestions based on module names:\n\n") freeze_suggestions = [] for name, module in pipeline.model.named_modules(): name_lower = name.lower() # Common modules that often need freezing if any(keyword in name_lower for keyword in [ 'vision_tower', 'image_processor', 'vision_model', 'embeddings', 'encoder', 'layernorm', 'norm', 'position_embedding', 'patch_embedding' ]): freeze_suggestions.append(name) if freeze_suggestions: f.write("# Suggested modules to freeze:\n") for suggestion in freeze_suggestions: f.write(f"pipeline.model.{suggestion}.requires_grad_(False)\n") else: f.write("# No obvious modules to freeze found, please set manually as needed\n") # 6. Training Suggestions f.write("\n\n6. Training Suggestions\n") f.write("-" * 50 + "\n") f.write("# Training suggestions based on model structure:\n\n") if 'vision' in str(pipeline.model).lower(): f.write("- Vision module detected, suggested to freeze pre-trained vision encoder\n") if 'language' in str(pipeline.model).lower() or 'llm' in str(pipeline.model).lower(): f.write("- Language module detected, suggested to freeze pre-trained language model backbone\n") if 'vae' in str(pipeline.model).lower(): f.write("- VAE module detected, suggested to freeze VAE encoder and decoder\n") f.write(f"\nSuggested to only train specific adaptation layers or newly added modules\n") f.write(f"High total parameter count({total_params:,}), suggested to use parameter-efficient training methods like LoRA\n") print(f"Model structure analysis saved to: {output_file}") # 7. Print key info to console print("\n" + "="*60) print("Model Structure Overview:") print("="*60) print(f"Total Parameters: {total_params:,}") print(f"Trainable Parameters: {trainable_params:,}") print(f"Frozen Parameters: {frozen_params:,}") print(f"Trainable Ratio: {trainable_params/total_params*100:.2f}%") print(f"\nDetailed analysis saved to: {output_file}") if __name__ == "__main__": print_model_structure()