# Pref-Restoration/trl/train_grpo.py
# GRPO training script (TRL GRPOTrainer) for the BLIP3o-NEXT text model on OCR prompts.
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer


# Load one prompt per non-blank line of ocr_train.txt, with surrounding
# whitespace stripped.
with open("ocr_train.txt", "r", encoding="utf-8") as f:
    prompts = [prompt for prompt in map(str.strip, f) if prompt]

train_dataset = Dataset.from_dict({"prompt": prompts})
# Dataset.shuffle is NOT in-place: it returns a new shuffled Dataset.
# The result must be reassigned, otherwise the shuffle is silently discarded.
train_dataset = train_dataset.shuffle(seed=0)


# GRPO run configuration, grouped by concern.
training_args = GRPOConfig(
    # Where checkpoints are written; save_steps sets how often.
    output_dir="BLIP3o-NEXT-Text-GRPO",
    save_steps=50,
    # Optimizer schedule.
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    # Batch size and number of sampled completions per prompt.
    per_device_train_batch_size=16,
    num_generations=16,
    # Loss settings: beta is the KL-penalty coefficient in TRL's GRPOConfig.
    beta=0.001,
    use_liger_loss=True,
)

# Log the resolved configuration for the run record.
print(training_args)

## dummy reward for testing
def reward_len(completions, **kwargs):
    """Score each completion by how close its length is to 20.

    Placeholder reward: 0 when ``len(completion) == 20``, and more negative
    the further the length is from 20. Any extra keyword arguments supplied
    by the trainer are accepted and ignored.

    Returns a list of ints, one per completion, in input order.
    """
    target = 20
    scores = []
    for text in completions:
        distance = abs(len(text) - target)
        scores.append(-distance)
    return scores

# Wire the config, data, and reward function into the GRPO trainer. The
# policy model is passed as a checkpoint path string, then training starts.
trainer = GRPOTrainer(
    args=training_args,
    train_dataset=train_dataset,
    reward_funcs=reward_len,
    model="/fsx/home/jiuhai.chen/BLIP3o-NEXT/models/debug",
)
trainer.train()