from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
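
# GRPO fine-tuning script for BLIP3o-NEXT text rendering: reads one prompt per
# line from ocr_train.txt and trains against a placeholder length-based reward.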
with open("ocr_train.txt", "r", encoding="utf-8") as f:
    prompts = [line.strip() for line in f if line.strip()]
train_dataset = Dataset.from_dict({"prompt": prompts})
# Dataset.shuffle returns a new dataset rather than shuffling in place,
# so the result must be reassigned.
train_dataset = train_dataset.shuffle(seed=0)
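
# Optional sanity check: each row should expose a "prompt" column,
# which is the field GRPOTrainer reads prompts from.
print(train_dataset[0])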
training_args = GRPOConfig(
    output_dir="BLIP3o-NEXT-Text-GRPO",
    use_liger_loss=True,  # use the Liger kernel implementation of the GRPO loss
    per_device_train_batch_size=16,
    num_generations=16,  # completions sampled per prompt for the group-relative baseline
    save_steps=50,
    lr_scheduler_type="cosine",
    learning_rate=1e-6,
    beta=0.001,  # KL penalty coefficient against the reference model
)
print(training_args)
# Dummy reward for testing: scores each completion by how close its
# character length is to 20 (maximum reward 0 at exactly 20 characters).
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]
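
# A sketch of a more task-specific reward, assuming the train dataset also
# carried a hypothetical "reference" column (it does not here). TRL forwards
# extra dataset columns to reward functions as keyword arguments, so such a
# column would arrive as the `reference` list below: score 1.0 for an exact
# match with the reference string, 0.0 otherwise.
def reward_exact_match(completions, reference=None, **kwargs):
    if reference is None:
        return [0.0] * len(completions)
    return [
        1.0 if completion.strip() == ref.strip() else 0.0
        for completion, ref in zip(completions, reference)
    ]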
trainer = GRPOTrainer(
    model="/fsx/home/jiuhai.chen/BLIP3o-NEXT/models/debug",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
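
# Persist the final policy (a common follow-up step, not part of the original
# script): Trainer.save_model writes the trained model weights to output_dir.
trainer.save_model(training_args.output_dir)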