"""
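Utility script: load a PrefRestorePipeline checkpoint and write a report of its model structure (module hierarchy, parameter statistics, freezing/training suggestions) to a text file, printing a short summary to the console.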
"""
import sys
import os
import torch
# Make the local DiffusionNFT checkout importable without installing it.
# It is prepended (not appended) so it takes precedence over any installed copy.
project_root = "/data/phd/yaozhengjian/Code/RL/ART-FRv2/DiffusionNFT"
if project_root in sys.path:
    pass
else:
    sys.path[0:0] = [project_root]
# from DiffusionNFT.model.PrefRestorePipeline_pipeline import PrefRestorePipeline
from PrefRestorePipeline_pipeline import PrefRestorePipeline
# Substrings (matched case-insensitively against module names) that usually mark
# pre-trained components worth freezing. This is a naming heuristic only.
_FREEZE_KEYWORDS = (
    'vision_tower', 'image_processor', 'vision_model',
    'embeddings', 'encoder', 'layernorm', 'norm',
    'position_embedding', 'patch_embedding',
)


def _count_params(module, recurse=True):
    """Return ``(total, trainable)`` element counts of ``module``'s parameters.

    With ``recurse=False`` only parameters registered directly on ``module``
    are counted, not those of its submodules.
    """
    total = trainable = 0
    for param in module.parameters(recurse=recurse):
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable


def _ratio_percent(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _collect_module_stats(model):
    """Map module name -> counts of the parameters that module owns directly.

    Every module (not only leaves) is visited with ``recurse=False``. This keeps
    parameters attached directly to container modules in the breakdown, and
    nesting never attributes one parameter to several modules. Modules owning
    no parameters are omitted. The root module's name is the empty string, so
    it is reported as ``"<root>"``.
    """
    stats = {}
    for name, module in model.named_modules():
        total, trainable = _count_params(module, recurse=False)
        if total > 0:
            stats[name or "<root>"] = {
                'total': total,
                'trainable': trainable,
                'frozen': total - trainable,
                'trainable_ratio': _ratio_percent(trainable, total),
            }
    return stats


def _collect_freeze_suggestions(model):
    """Return names of the top-most modules whose name matches a freeze keyword.

    Descendants of an already-suggested module are skipped.
    ``requires_grad_(False)`` on the ancestor already freezes them.
    """
    suggestions = []
    suggested = set()
    for name, _module in model.named_modules():
        if not name:  # root module
            continue
        parts = name.split(".")
        # named_modules() yields a parent before its children, so a matching
        # ancestor is already in `suggested` when its descendants come up.
        if any(".".join(parts[:depth]) in suggested for depth in range(1, len(parts))):
            continue
        name_lower = name.lower()
        if any(keyword in name_lower for keyword in _FREEZE_KEYWORDS):
            suggestions.append(name)
            suggested.add(name)
    return suggestions


def print_model_structure(
    model_path="/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restoration_FFHQ_VAE_Step3_scaling/checkpoint-108000",
    output_file="PrefRestorePipeline_model_structure.txt",
    *,
    pipeline=None,
):
    """Write a detailed structure/parameter report for a PrefRestorePipeline.

    The six-section report is written to ``output_file`` (overwritten, UTF-8).
    A short parameter summary is also printed to the console.

    Args:
        model_path: Checkpoint passed to ``PrefRestorePipeline.from_pretrained``.
            It is also recorded in the report. Defaults to the original
            hard-coded checkpoint.
        output_file: Path of the text report to write.
        pipeline: An already-loaded pipeline to analyse. When given, nothing is
            loaded and ``model_path`` is only used as a label in the report.
            It must expose ``model`` (a ``torch.nn.Module``), ``device``,
            ``config.dtype``, ``tokenizer`` and ``processor``.
    """
    if pipeline is None:
        use_cuda = torch.cuda.is_available()
        print("Loading PrefRestorePipeline...")
        pipeline = PrefRestorePipeline.from_pretrained(
            model_path=model_path,
            device="cuda:0" if use_cuda else "cpu",
            dtype=torch.bfloat16 if use_cuda else torch.float32,
        )
        print("Pipeline loaded successfully!")

    model = pipeline.model
    total_params, trainable_params = _count_params(model)
    frozen_params = total_params - trainable_params
    # Guarded: a parameter-less model would otherwise divide by zero.
    trainable_ratio = _ratio_percent(trainable_params, total_params)

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("=" * 100 + "\n")
        f.write("PrefRestorePipeline Model Structure Detailed Analysis\n")
        f.write("=" * 100 + "\n\n")

        # 1. Basic Information
        f.write("1. Basic Information\n")
        f.write("-" * 50 + "\n")
        f.write(f"Model Path: {model_path}\n")
        f.write(f"Device: {pipeline.device}\n")
        f.write(f"Data Type: {pipeline.config.dtype}\n")
        f.write(f"Main Model Type: {type(model).__name__}\n")
        f.write(f"Tokenizer Type: {type(pipeline.tokenizer).__name__}\n")
        f.write(f"Processor Type: {type(pipeline.processor).__name__}\n\n")

        # 2. Model Hierarchy
        f.write("2. Model Hierarchy\n")
        f.write("-" * 50 + "\n")
        for name, module in model.named_modules():
            f.write(f"{name}: {type(module).__name__}\n")
        f.write("\n")

        # 3. Parameter Statistics
        f.write("3. Parameter Statistics\n")
        f.write("-" * 50 + "\n")
        f.write(f"Total Parameters: {total_params:,}\n")
        f.write(f"Trainable Parameters: {trainable_params:,}\n")
        f.write(f"Frozen Parameters: {frozen_params:,}\n")
        f.write(f"Trainable Ratio: {trainable_ratio:.2f}%\n\n")

        # 4. Detailed Parameter Breakdown by Module, largest first
        f.write("4. Detailed Parameter Breakdown by Module\n")
        f.write("-" * 50 + "\n")
        module_stats = _collect_module_stats(model)
        sorted_modules = sorted(module_stats.items(), key=lambda item: item[1]['total'], reverse=True)
        for name, stats in sorted_modules:
            f.write(f"\n{name}:\n")
            f.write(f" Total Parameters: {stats['total']:,}\n")
            f.write(f" Trainable: {stats['trainable']:,}\n")
            f.write(f" Frozen: {stats['frozen']:,}\n")
            f.write(f" Trainable Ratio: {stats['trainable_ratio']:.2f}%\n")

        # 5. Suggested Freezing Strategy
        f.write("\n\n5. Suggested Freezing Strategy\n")
        f.write("-" * 50 + "\n")
        f.write("# Freeze suggestions based on module names:\n\n")
        freeze_suggestions = _collect_freeze_suggestions(model)
        if freeze_suggestions:
            f.write("# Suggested modules to freeze:\n")
            for suggestion in freeze_suggestions:
                # get_submodule() accepts dotted paths with numeric indices such as
                # "layers.0.norm", which plain attribute syntax cannot express.
                f.write(f"pipeline.model.get_submodule({suggestion!r}).requires_grad_(False)\n")
        else:
            f.write("# No obvious modules to freeze found, please set manually as needed\n")

        # 6. Training Suggestions (keyword heuristics on the model's repr)
        f.write("\n\n6. Training Suggestions\n")
        f.write("-" * 50 + "\n")
        f.write("# Training suggestions based on model structure:\n\n")
        model_repr = str(model).lower()  # built once; can be very large
        if 'vision' in model_repr:
            f.write("- Vision module detected, suggested to freeze pre-trained vision encoder\n")
        if 'language' in model_repr or 'llm' in model_repr:
            f.write("- Language module detected, suggested to freeze pre-trained language model backbone\n")
        if 'vae' in model_repr:
            f.write("- VAE module detected, suggested to freeze VAE encoder and decoder\n")
        f.write("\nSuggested to only train specific adaptation layers or newly added modules\n")
        # NOTE(review): written unconditionally, whatever the actual parameter count.
        f.write(f"High total parameter count({total_params:,}), suggested to use parameter-efficient training methods like LoRA\n")

    print(f"Model structure analysis saved to: {output_file}")

    # 7. Print key info to console
    print("\n" + "=" * 60)
    print("Model Structure Overview:")
    print("=" * 60)
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Frozen Parameters: {frozen_params:,}")
    print(f"Trainable Ratio: {trainable_ratio:.2f}%")
    print(f"\nDetailed analysis saved to: {output_file}")
if __name__ == "__main__":
    # Run the analysis only when executed as a script, never on import.
    print_model_structure()