Pref-Restoration / trl / inference.py
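"""Text-to-image inference for BLIP3o-NEXT.

Loads a pretrained blip3oQwenForInferenceLM checkpoint, builds a chat-formatted
prompt for each caption, and generates images via the model's generate_images
API, saving the results as PNG files.
"""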
from dataclasses import dataclass
import torch
from PIL import Image
from transformers import AutoTokenizer
from blip3o.model import *
import os


@dataclass
class T2IConfig:
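    """Checkpoint path, device, and sampling settings for text-to-image generation."""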
    model_path: str = "/fsx/home/jiuhai.chen/BLIP3o-NEXT/models/debug"
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16
    # generation config
    scale: int = 0      # index n of the <S{n}> scale token appended after <im_start>
    seq_len: int = 729  # maximum number of image tokens to generate (max_new_tokens)
    top_p: float = 0.95
    top_k: int = 1200

class TextToImageInference:
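    """Loads the BLIP3o model and tokenizer and generates one image per prompt."""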
    def __init__(self, config: T2IConfig):
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()
        
    def _load_models(self):
        # Load the BLIP3o checkpoint in the configured dtype, plus its matching tokenizer.
        self.model = blip3oQwenForInferenceLM.from_pretrained(
            self.config.model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

    def generate_image(self, prompt: str) -> Image.Image:
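        """Generate a single PIL image for the given text prompt."""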

        batch_messages = []
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Please generate image based on the following caption: {prompt}"}
        ]
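        # Render the chat messages into the model's prompt string (tokenization happens below).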
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True)
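        # Append the image-start marker and the <S{scale}> token to cue image-token generation.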
        input_text += f"<im_start><S{self.config.scale}>"
        
        batch_messages.append(input_text)

        # Tokenize the prompts as a batch, left-padded for decoder-only generation.
        inputs = self.tokenizer(
            batch_messages, return_tensors="pt", padding=True, truncation=True, padding_side="left"
        )
    
        # Sample image tokens with top-p/top-k sampling and decode them into PIL images.
        gen_ids, output_image = self.model.generate_images(
            inputs.input_ids.to(self.device),
            inputs.attention_mask.to(self.device),
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            top_p=self.config.top_p,
            top_k=self.config.top_k)

        # generate_images returns one image per prompt; only a single prompt is batched here.
        return output_image[0]


def main():
    config = T2IConfig()
    inference = TextToImageInference(config)

    prompts = [
        "A cute cat",
    ]

    output_dir = "BLIP3o-NEXT"
    os.makedirs(output_dir, exist_ok=True)

    # Generate and save one PNG per prompt.
    for idx, prompt in enumerate(prompts):
        image = inference.generate_image(prompt)
        save_path = os.path.join(output_dir, f"blip3o_next_{idx:02d}.png")
        image.save(save_path)
        print(f"Saved: {save_path}")



if __name__ == "__main__":
    main()