""" Text-to-Image Pipeline for ART-FR Model Interface similar to diffusers pipeline, convenient for calling TextToImageInference """ import os import sys import torch from PIL import Image from dataclasses import dataclass from typing import Union, Optional, List from torchvision.transforms import v2 from transformers import AutoTokenizer # Assuming these are modules in your project, paths may need adjustments repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) if repo_root not in sys.path: sys.path.insert(0, repo_root) from blip3o.model.language_model.blip3o_qwen_inference_vae import blip3oQwenForInferenceLMVAE IMAGE_TOKEN_INDEX = -200 @dataclass class PipelineConfig: """Pipeline configuration class""" model_path: str = "/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restore_restoration-FFHQ+CelebA/checkpoint-30800" device: str = "cuda:0" dtype: torch.dtype = torch.bfloat16 # Generation config scale: int = 0 seq_len: int = 729 top_p: float = 0.95 top_k: int = 1200 # Image processing config image_size: int = 512 class PrefRestorePipeline: """ Text-to-image generation pipeline similar to diffusers pipeline Usage example: ```python from DiffusionNFT.model.text_to_image_pipeline import PrefRestorePipeline # Initialize pipeline pipeline = PrefRestorePipeline.from_pretrained(model_path="your_model_path") # Generate image result = pipeline( prompt="Please reconstruct the given image.", image="path/to/image.jpg", num_inference_steps=50 ) # Save result result.images[0].save("output.jpg") ``` """ def __init__(self, config: PipelineConfig): self.config = config self.device = torch.device(config.device) self._setup_transforms() self._load_models() @classmethod def from_pretrained( cls, model_path: str, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, **kwargs ): """Create pipeline from pre-trained model""" config = PipelineConfig( model_path=model_path, device=device, dtype=dtype, **kwargs ) return cls(config) def _setup_transforms(self): """Set up image transforms""" self.target_transform = v2.Compose([ v2.Resize(self.config.image_size), v2.CenterCrop(self.config.image_size), v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize([0.5], [0.5]), ]) def _load_models(self): """Load model and tokenizer""" self.model = blip3oQwenForInferenceLMVAE.from_pretrained( self.config.model_path, torch_dtype=self.config.dtype ).to(self.device) self.transformer = self.model.model.sana.transformer_blocks self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path) self.processor = self.model.get_vision_tower().image_processor def _process_image(self, image: Union[str, Image.Image, List[Union[str, Image.Image]]]) -> tuple: """Process input image (supports single image or list of images)""" if isinstance(image, list): # Process image list processed_images = [] image_sizes = [] original_images = [] for img in image: if isinstance(img, str): img = Image.open(img).convert("RGB") elif not isinstance(img, Image.Image): raise ValueError("Each image must be a PIL Image or file path") image_size = img.size processed_image = self.processor.preprocess(img, return_tensors="pt")["pixel_values"][0] processed_images.append(processed_image) image_sizes.append(image_size) original_images.append(img) return processed_images, image_sizes, original_images else: # Process single image (keep original logic) if isinstance(image, str): image = Image.open(image).convert("RGB") elif not isinstance(image, Image.Image): raise ValueError("image must be a PIL Image or file path") 
image_size = image.size processed_image = self.processor.preprocess(image, return_tensors="pt")["pixel_values"][0] return processed_image, image_size, image def _process_target_image(self, image: Union[Image.Image, List[Image.Image]]): """Process target image (supports single image or list of images)""" if isinstance(image, list): # Process image list processed_images = [] for img in image: if not isinstance(img, Image.Image): raise ValueError("Each target image must be a PIL Image") processed_img = self.target_transform(img) processed_images.append(processed_img) return processed_images else: # Process single image if not isinstance(image, Image.Image): raise ValueError("Target image must be a PIL Image") return self.target_transform(image) def _preprocess_qwen(self, sources, has_image: bool = True, max_len=2048, system_message: str = "You are a helpful assistant."): """Preprocess Qwen input""" roles = {"human": "user", "gpt": "assistant"} if 'image_token_index' not in globals(): self.tokenizer.add_tokens([""], special_tokens=True) global image_token_index image_token_index = self.tokenizer.convert_tokens_to_ids("") im_start, im_end = self.tokenizer.additional_special_tokens_ids[:2] unmask_tokens_idx = [198, im_start, im_end] chat_template = ( "{% for message in messages %}" "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" "{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" ) self.tokenizer.chat_template = chat_template input_ids, targets = [], [] for source in sources: if roles[source[0]["from"]] != roles["human"]: source = source[1:] input_id, target = [], [] input_id += self.tokenizer.apply_chat_template([{"role": "system", "content": system_message}]) target += input_id for conv in source: try: role = conv["role"] content = conv["content"] except: role = conv["from"] content = conv["value"] role = roles.get(role, role) conv = [{"role": role, "content": content}] encode_id = self.tokenizer.apply_chat_template(conv) if role == roles["human"]: input_id += encode_id target += encode_id else: input_id += encode_id[:-2] target += encode_id[:-2] assert len(input_id) == len(target) for idx, encode_id in enumerate(input_id): if encode_id in unmask_tokens_idx: target[idx] = encode_id if encode_id == image_token_index: input_id[idx] = IMAGE_TOKEN_INDEX input_ids.append(input_id) targets.append(target) input_ids = torch.tensor(input_ids, dtype=torch.long) targets = torch.tensor(targets, dtype=torch.long) return dict(input_ids=input_ids, labels=targets) def __call__( self, prompt: str = "Please reconstruct the given image.", image: Union[str, Image.Image] = None, num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, **kwargs ): """ Main interface for image generation Args: prompt: Text prompt image: Input image (path or PIL Image object) num_inference_steps: Number of inference steps (not yet used, kept for compatibility) guidance_scale: Guidance scale (not yet used, kept for compatibility) **kwargs: Other parameters Returns: Result object containing generated images """ if image is None: raise ValueError("image is required for this pipeline") # Process image processed_image, image_size, original_image = self._process_image(image) detailed_condition = self._process_target_image(original_image) # Prepare messages messages = [ {"from": "human", "value": "\nPlease reconstruct the given image."}, {"from": "gpt", "value": f""} ] # Preprocess input data_dict = self._preprocess_qwen([messages], 
has_image=True) inputs = data_dict['input_ids'] # Generate image with torch.no_grad(): output_images = self.model.generate_images_from_image( inputs.to(self.device), images=[processed_image], detailed_conditions=[detailed_condition], max_new_tokens=self.config.seq_len, do_sample=True, top_p=self.config.top_p, top_k=self.config.top_k, ) # Return results return PipelineResult( images=output_images, original_image=original_image, prompt=prompt ) class PipelineResult: """Pipeline result class""" def __init__(self, images: List[Image.Image], original_image: Image.Image, prompt: str): self.images = images self.original_image = original_image self.prompt = prompt def save(self, output_path: str, index: int = 0): """Save generated image""" if index < len(self.images): self.images[index].save(output_path) else: raise IndexError(f"Index {index} out of range. Only {len(self.images)} images available.") # Convenience functions def load_pipeline(model_path: str, **kwargs) -> PrefRestorePipeline: """Convenience function to load pipeline""" return PrefRestorePipeline.from_pretrained(model_path, **kwargs)
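

# Example usage (illustrative sketch only): the checkpoint directory and the input
# image path below are placeholders, and a CUDA device is assumed to be available.
# Adjust both to your environment before running this module as a script.
if __name__ == "__main__":
    pipeline = load_pipeline(
        model_path="/path/to/your/checkpoint",  # placeholder checkpoint directory
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    result = pipeline(
        prompt="Please reconstruct the given image.",
        image="path/to/degraded_face.jpg",  # placeholder input image
    )

    # Save the first generated image to the working directory.
    result.save("restored.jpg", index=0)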