import os
from dataclasses import dataclass

import torch
from PIL import Image
from transformers import AutoTokenizer

# Star import kept as in the original: it is what brings
# `blip3oQwenForInferenceLM` (used in `_load_models`) into scope.
from blip3o.model import *


@dataclass
class T2IConfig:
    """Settings for loading the BLIP3o-NEXT model and sampling one image."""

    # Directory passed to `from_pretrained` for both the model and tokenizer.
    model_path: str = "/fsx/home/jiuhai.chen/BLIP3o-NEXT/models/debug"
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16

    # --- generation config ---
    # NOTE(review): `scale` is not read anywhere in this script; kept so
    # callers that construct T2IConfig(scale=...) keep working.
    scale: int = 0
    # Forwarded to `generate_images` as `max_new_tokens`.
    seq_len: int = 729
    top_p: float = 0.95
    top_k: int = 1200


class TextToImageInference:
    """Wrap a BLIP3o-NEXT model and its tokenizer for prompt-to-image generation."""

    def __init__(self, config: T2IConfig):
        """Store `config`, resolve the target device, and load model + tokenizer."""
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        """Load the model (cast to `config.dtype`, moved to `self.device`) and tokenizer."""
        self.model = blip3oQwenForInferenceLM.from_pretrained(
            self.config.model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

    def generate_image(self, prompt: str) -> Image.Image:
        """Generate an image for a single text prompt.

        The prompt is wrapped in a fixed system/user chat template, tokenized
        as a one-element left-padded batch, and passed to the model's
        `generate_images` with sampling enabled and the configured
        `seq_len` / `top_p` / `top_k`.

        Args:
            prompt: Caption describing the image to generate.

        Returns:
            The first element of the image batch returned by
            `generate_images` (presumably a PIL image, since callers `.save()`
            it — confirm against blip3o).
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": f"Please generate image based on the following caption: {prompt}",
            },
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # NOTE(review): the original appended an empty f-string here, which is
        # a no-op. If the model expects an image-start special token right
        # after the generation prompt, it must be appended explicitly — confirm
        # against the training prompt format.
        batch_messages = [input_text]

        # Left padding so generation continues directly from the prompt end.
        inputs = self.tokenizer(
            batch_messages,
            return_tensors="pt",
            padding=True,
            truncation=True,
            padding_side="left",
        )
        _gen_ids, output_images = self.model.generate_images(
            inputs.input_ids.to(self.device),
            inputs.attention_mask.to(self.device),
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            top_p=self.config.top_p,
            top_k=self.config.top_k,
        )
        return output_images[0]


def main():
    """Generate one image per prompt and save them as PNGs under `BLIP3o-NEXT/`."""
    config = T2IConfig()
    inference = TextToImageInference(config)
    prompts = [
        "A cute cat",
    ]
    output_dir = "BLIP3o-NEXT"
    os.makedirs(output_dir, exist_ok=True)
    for idx, prompt in enumerate(prompts):
        image = inference.generate_image(prompt)
        save_path = os.path.join(output_dir, f"blip3o_next_{idx:02d}.png")
        image.save(save_path)
        print(f"Saved: {save_path}")


if __name__ == "__main__":
    main()