import logging
import pathlib
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import deepspeed
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer
from deepspeed.runtime.fp16.loss_scaler import LossScaler
from blip3o.data import make_supervised_data_module
from blip3o.model import blip3oQwenForCausalLM
from blip3o.train.blip3o_trainer import blip3oTrainer
from blip3o.utils import rank0_print
from tabulate import tabulate
import numpy as np
# allowlist the objects that torch.load(weights_only=True) needs when restoring checkpoints:
# numpy's array-reconstruct helper and DeepSpeed's LossScaler
torch.serialization.add_safe_globals([np._core.multiarray._reconstruct])
# use the file_system sharing strategy to avoid exhausting file descriptors in DataLoader workers
torch.multiprocessing.set_sharing_strategy("file_system")
torch.serialization.add_safe_globals([LossScaler])
local_rank = None
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
diffusion_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to initialize the model class; format is XXXXForCausalLM, where XXXX is currently one of blip3oLlama, blip3oMixtral, blip3oMistral, Llama."})
mm_tunable_parts: Optional[str] = field(default="mm_language_model")
version: Optional[str] = field(default="v0")
vision_tower: Optional[str] = field(default=None)
    vision_tower_pretrained: Optional[str] = field(default=None)
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
mm_use_im_start_end: bool = field(default=False)
mm_patch_merge_type: Optional[str] = field(default="flat")
mm_vision_select_feature: Optional[str] = field(default="patch")
rope_scaling_factor: Optional[float] = field(default=None)
rope_scaling_type: Optional[str] = field(default=None)
use_pos_skipping: Optional[bool] = field(default=False)
pos_skipping_range: Optional[int] = field(default=4096)
delay_load: Optional[bool] = field(default=True)
num_image_tokens: Optional[int] = field(default=-1)
    image_token_format: str = field(default="<I{}>")  # formatted with an index, e.g. "<I0>", "<I1>", ...
num_scale_tokens: Optional[int] = field(default=3)
    scale_token_format: str = field(default="<S{}>")  # e.g. "<S0>", "<S1>", "<S2>" for the default 3 scale tokens
load_embeddings_from_vision: Optional[bool] = field(default=False)
@dataclass
class DataArguments:
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data, in blip3o's instruction.json format. Supports multiple JSON files via /path/to/{a,b,c}.json."})
lazy_preprocess: bool = False
is_multimodal: bool = False
early_mix_text: bool = False
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = "square"
dataset_cls: str = field(default="blip3o")
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=4096,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
mm_vision_tower_lr: Optional[float] = None
group_by_varlen: bool = field(default=False)
group_by_modality_length: bool = field(default=False)
group_by_modality_length_auto: bool = field(default=False)
auto_find_batch_size: bool = field(default=False)
gradient_checkpointing: bool = field(default=True)
    attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Attention implementation passed to transformers (e.g. flash_attention_2, sdpa, eager)."})
dispatch_batches: Optional[bool] = field(default=None)
split_batches: Optional[bool] = field(default=None)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
trainer.accelerator.wait_for_everyone()
torch.cuda.synchronize()
    if trainer.deepspeed:  # let DeepSpeed gather and save the (possibly partitioned) weights via trainer.save_model
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def get_model(model_args, training_args):
customized_kwargs = {}
overwrite_config = {}
cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path) # blip3oQwenConfig has been registered
'''——————————————————————————————————————Default not used——————————————————————————————————————————'''
if model_args.use_pos_skipping is not None and model_args.pos_skipping_range is not None:
overwrite_config["use_pos_skipping"] = model_args.use_pos_skipping
overwrite_config["pos_skipping_range"] = model_args.pos_skipping_range
    '''
    Expanding the model's context window:
    Many pre-trained language models (e.g. Llama, Mistral) have a fixed maximum sequence length (e.g. 4096 tokens);
    to handle longer text, the context window has to be extended, which is what the rope_scaling block below does.
    RoPE (Rotary Position Embedding) is the positional encoding these models use, and RoPE scaling adjusts how RoPE
    is computed so the model can handle sequences longer than its original training length.
    In practice: when rope_scaling_factor (e.g. 4.0) and rope_scaling_type (e.g. "linear" or "dynamic") are passed at
    launch, this code computes the new model_max_length (e.g. 4096 * 4.0 = 16384) and writes these settings into the
    model config, so the loaded model can be fine-tuned on 16384-token sequences.
    '''
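    # Illustrative sketch only (the flag values below are hypothetical, not this script's defaults):
    # launching with --rope_scaling_factor 4.0 --rope_scaling_type linear --model_max_length 16384
    # (for a base model whose max_position_embeddings is 4096) results in
    #   overwrite_config["rope_scaling"] = {"factor": 4.0, "type": "linear"}
    #   overwrite_config["max_sequence_length"] = 16384
    # and the assert below checks that model_max_length == 4096 * 4.0.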
if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
overwrite_config["rope_scaling"] = {
"factor": model_args.rope_scaling_factor,
"type": model_args.rope_scaling_type,
}
if training_args.model_max_length is None:
training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor
overwrite_config["max_sequence_length"] = training_args.model_max_length
        assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), (
            f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}"
        )
'''——————————————————————————————————————————————————————————————————————————————————————————'''
if overwrite_config:
assert cfg_pretrained is not None, "cfg_pretrained is None"
rank0_print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(cfg_pretrained, k, v)
customized_kwargs["config"] = cfg_pretrained
model = blip3oQwenForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=training_args.attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
low_cpu_mem_usage=False,
**customized_kwargs)
return model
def train():
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
model = get_model(model_args, training_args)
model.config.use_cache = False
if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
model.config.rope_scaling = {
"factor": model_args.rope_scaling_factor,
"type": model_args.rope_scaling_type,
}
if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):  # this branch is taken: HF PreTrainedModel provides this method
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right")
    if tokenizer.unk_token is not None:  # not taken here: Qwen's tokenizer has "unk_token": null
        tokenizer.pad_token = tokenizer.unk_token  # Qwen/BLIP-3o already sets "pad_token": "<|endoftext|>"
if model_args.vision_tower is not None:
        model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)  # during pre-training this step also loads the other modules: ta-tok, sana, sana_vae, diffusion_connector
vision_tower = model.get_vision_tower() # TA-Tok
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
data_args.image_processor = vision_tower.image_processor
data_args.is_multimodal = True
model.config.image_aspect_ratio = data_args.image_aspect_ratio # square
model.config.diffusion_name_or_path = model_args.diffusion_name_or_path
model.config.tokenizer_padding_side = tokenizer.padding_side # right
        model.config.tokenizer_model_max_length = tokenizer.model_max_length  # set via the launch script (e.g. 2048)
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end # True
model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr # None
training_args.use_im_start_end = model_args.mm_use_im_start_end # True
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)  # add the new image tokens, scale tokens, and start/end tokens to the tokenizer
### Deciding which part of the model to train
rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}") # mm_language_model trains other parts except vision tower
model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts
# Set the entire model to not require gradients by default
model.requires_grad_(False)
vision_tower.requires_grad_(False)
vision_tower.eval()
# Parse the mm_tunable_parts to decide which parts to unfreeze
tunable_parts = model_args.mm_tunable_parts.split(",")
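        # e.g. (hypothetical value) mm_tunable_parts="mm_vision_tower,mm_language_model"
        # would give tunable_parts == ["mm_vision_tower", "mm_language_model"]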
if "mm_vision_tower" in tunable_parts:
for name, param in model.named_parameters():
if "vision_tower" in name:
param.requires_grad_(True)
if "mm_language_model" in tunable_parts:
for name, param in model.named_parameters():
if "vision_tower" not in name:
param.requires_grad_(True)
if 'mm_embedding' in tunable_parts:
for name, param in model.named_parameters():
if "embed_tokens" in name or 'lm_head' in name:
param.requires_grad_(True)
## freeze sana except the caption projection
        for name, param in model.named_parameters():  # sana is kept frozen to preserve its original generation capability
if "sana" in name:
param.requires_grad_(False)
for name, param in model.named_parameters():
if "caption" in name:
param.requires_grad_(True)
        # Summary of trainable parameters (~1874.48M): Qwen embeddings, Qwen layers, sana caption projection, diffusion_connector
        total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters())
        trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad)
        rank0_print(f"Total parameters: ~{total_params/1e6:.2f}M")
        rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f}M")
# for name, p in model.named_parameters():
# if p.requires_grad:
# rank0_print(f"Trainable parameter: {name}")
# Create a file to save trainable parameter information
trainable_params_file = "trainable_parameters.txt"
with open(trainable_params_file, "w") as f:
f.write("=== Trainable Parameters ===\n\n")
total_params = 0
for name, p in model.named_parameters():
if p.requires_grad:
param_count = p.numel()
param_info = f"Parameter: {name}\n Shape: {list(p.shape)}\n Count: {param_count:,}\n\n"
f.write(param_info)
total_params += param_count
f.write(f"=== Summary ===\n")
f.write(f"Total trainable parameters: {total_params:,}\n")
f.write(f"Total trainable parameters (M): {total_params/1e6:.2f}M\n")
rank0_print(f"Trainable parameters saved to: {trainable_params_file}")
rank0_print(f"Total trainable parameters: {total_params:,} ({total_params/1e6:.2f}M)")
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = blip3oTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
if trainer.is_world_process_zero():
stat = []
for i, (n, p) in enumerate(trainer.model.named_parameters()):
stat.append([i, n, p.shape, p.requires_grad])
print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
rank0_print(f"Model saved to {training_args.output_dir}")
if __name__ == "__main__":
train()
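# Example launch (a minimal sketch only; the script name, DeepSpeed config, model/data paths,
# and hyperparameter values below are placeholders, not values shipped with this file):
#
#   deepspeed --num_gpus=8 train.py \
#       --deepspeed scripts/zero2.json \
#       --model_name_or_path /path/to/qwen_base \
#       --diffusion_name_or_path /path/to/diffusion_model \
#       --vision_tower /path/to/vision_tower \
#       --data_path /path/to/{a,b,c}.json \
#       --image_folder /path/to/images \
#       --mm_tunable_parts mm_language_model \
#       --bf16 True \
#       --model_max_length 2048 \
#       --output_dir ./checkpoints/blip3o_run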