# Pref-Restoration / DiffusionNFT / model / PrefRestorePipeline_pipeline.py
"""
Text-to-Image Pipeline for the ART-FR model.

The interface is similar to a diffusers pipeline, which is convenient for calling TextToImageInference.
"""
import os
import sys
import torch
from PIL import Image
from dataclasses import dataclass
from typing import Union, Optional, List
from torchvision.transforms import v2
from transformers import AutoTokenizer

# Put the directory two levels above this file on ``sys.path`` (once) before
# importing ``blip3o`` -- presumably the repository root that contains that
# package. Adjust the relative path if this file is moved.
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
from blip3o.model.language_model.blip3o_qwen_inference_vae import blip3oQwenForInferenceLMVAE


# Placeholder id written into ``input_ids`` wherever the tokenizer's ``<image>``
# token id appears (see ``PrefRestorePipeline._preprocess_qwen``).
IMAGE_TOKEN_INDEX = -200


@dataclass
class PipelineConfig:
    """Settings consumed by ``PrefRestorePipeline``.

    Attributes:
        model_path: Checkpoint location passed to ``from_pretrained`` for both
            the model and the tokenizer.
        device: Torch device string the model is moved to.
        dtype: Value passed as ``torch_dtype`` when loading the model.
        scale: Integer interpolated into the ``<S{scale}>`` tag that opens the
            assistant turn.
        seq_len: Passed as ``max_new_tokens`` to generation.
        top_p: Nucleus-sampling threshold passed to generation.
        top_k: Top-k cutoff passed to generation.
        image_size: Target size for the resize + center crop applied to the
            conditioning image.
    """

    # --- model loading ---
    model_path: str = "/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restore_restoration-FFHQ+CelebA/checkpoint-30800"
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16

    # --- generation ---
    scale: int = 0
    seq_len: int = 729
    top_p: float = 0.95
    top_k: int = 1200

    # --- image preprocessing ---
    image_size: int = 512


class PrefRestorePipeline:
    """
    Image-conditioned generation pipeline with a diffusers-style interface.

    Wraps a ``blip3oQwenForInferenceLMVAE`` checkpoint and its tokenizer: an
    input image plus an instruction prompt go in, a ``PipelineResult`` holding
    the generated images comes out.

    Usage example:
    ```python
    from DiffusionNFT.model.PrefRestorePipeline_pipeline import PrefRestorePipeline

    # Initialize pipeline
    pipeline = PrefRestorePipeline.from_pretrained(model_path="your_model_path")

    # Generate image
    result = pipeline(
        prompt="Please reconstruct the given image.",
        image="path/to/image.jpg",
        num_inference_steps=50
    )

    # Save result
    result.images[0].save("output.jpg")
    ```
    """

    # Literal placeholder that marks where the image goes in the user turn.
    _IMAGE_TOKEN = "<image>"
    # Instruction used when the caller passes ``prompt=None``.
    _DEFAULT_PROMPT = "Please reconstruct the given image."
    # ChatML-style template installed on the tokenizer before encoding.
    _CHAT_TEMPLATE = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    )

    def __init__(self, config: PipelineConfig):
        """Build the transforms and load the model described by *config*."""
        self.config = config
        self.device = torch.device(config.device)
        # Id of the tokenizer's "<image>" token; resolved lazily per instance.
        self._image_token_id = None
        self._setup_transforms()
        self._load_models()

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        **kwargs
    ) -> "PrefRestorePipeline":
        """Create a pipeline from a pre-trained checkpoint.

        Args:
            model_path: Checkpoint location for the model and tokenizer.
            device: Torch device string the model is moved to.
            dtype: ``torch_dtype`` used when loading the model.
            **kwargs: Any further ``PipelineConfig`` fields.

        Returns:
            A fully initialized pipeline.
        """
        config = PipelineConfig(
            model_path=model_path,
            device=device,
            dtype=dtype,
            **kwargs
        )
        return cls(config)

    def _setup_transforms(self):
        """Build the transform that turns a PIL image into the detailed-condition tensor."""
        size = self.config.image_size
        self.target_transform = v2.Compose([
            v2.Resize(size),
            v2.CenterCrop(size),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # Maps the [0, 1] range produced by ToDtype(scale=True) to [-1, 1].
            v2.Normalize([0.5], [0.5]),
        ])

    def _load_models(self):
        """Load the model, tokenizer and the vision tower's image processor."""
        self.model = blip3oQwenForInferenceLMVAE.from_pretrained(
            self.config.model_path,
            torch_dtype=self.config.dtype
        ).to(self.device)
        self.transformer = self.model.model.sana.transformer_blocks
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
        self.processor = self.model.get_vision_tower().image_processor

    @staticmethod
    def _load_rgb(image, error_message: str) -> Image.Image:
        """Return *image* as an RGB PIL image, opening it first if it is a path.

        Raises:
            ValueError: With *error_message* if *image* is neither a path nor a
                PIL image.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() decodes the pixels into a new image, so the file handle
            # can be released as soon as the block exits.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if not isinstance(image, Image.Image):
            raise ValueError(error_message)
        # Treat in-memory images like images loaded from disk: always RGB.
        return image if image.mode == "RGB" else image.convert("RGB")

    def _to_pixel_values(self, image: Image.Image):
        """Run the vision tower's image processor on one PIL image."""
        return self.processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

    def _process_image(
        self,
        image: Union[str, os.PathLike, Image.Image, List[Union[str, os.PathLike, Image.Image]]],
    ) -> tuple:
        """Prepare the input image(s) for the vision tower.

        Args:
            image: A file path or PIL image, or a list of them.

        Returns:
            ``(pixel_values, size, pil_image)`` for a single input, or three
            parallel lists with the same content for a list input. ``size`` is
            the PIL ``(width, height)`` of the RGB image.

        Raises:
            ValueError: If an entry is neither a path nor a PIL image.
        """
        if isinstance(image, list):
            tensors, sizes, pil_images = [], [], []
            for item in image:
                pil = self._load_rgb(item, "Each image must be a PIL Image or file path")
                tensors.append(self._to_pixel_values(pil))
                sizes.append(pil.size)
                pil_images.append(pil)
            return tensors, sizes, pil_images

        pil = self._load_rgb(image, "image must be a PIL Image or file path")
        return self._to_pixel_values(pil), pil.size, pil

    def _process_target_image(self, image: Union[Image.Image, List[Image.Image]]):
        """Apply ``target_transform`` to a PIL image or to each image of a list.

        Raises:
            ValueError: If an entry is not a PIL image.
        """
        if isinstance(image, list):
            transformed = []
            for item in image:
                if not isinstance(item, Image.Image):
                    raise ValueError("Each target image must be a PIL Image")
                transformed.append(self.target_transform(item))
            return transformed

        if not isinstance(image, Image.Image):
            raise ValueError("Target image must be a PIL Image")
        return self.target_transform(image)

    def _prepare_tokenizer(self) -> int:
        """Install the chat template and ``<image>`` token; return that token's id.

        The id is cached on the instance rather than in a module global, so
        every pipeline registers the token on its *own* tokenizer.
        """
        self.tokenizer.chat_template = self._CHAT_TEMPLATE
        token_id = getattr(self, "_image_token_id", None)
        if token_id is None:
            self.tokenizer.add_tokens([self._IMAGE_TOKEN], special_tokens=True)
            token_id = self.tokenizer.convert_tokens_to_ids(self._IMAGE_TOKEN)
            self._image_token_id = token_id
        return token_id

    @staticmethod
    def _role_and_content(conv, roles):
        """Read one turn in either ``role``/``content`` or ``from``/``value`` form."""
        try:
            role, content = conv["role"], conv["content"]
        except KeyError:
            role, content = conv["from"], conv["value"]
        return roles.get(role, role), content

    def _preprocess_qwen(self, sources, has_image: bool = True, max_len=2048,
                        system_message: str = "You are a helpful assistant."):
        """Tokenize conversations with the ChatML-style template.

        Args:
            sources: List of conversations; each conversation is a list of turn
                dicts using ``from``/``value`` or ``role``/``content`` keys.
            has_image: Unused; kept for signature compatibility.
            max_len: Unused; kept for signature compatibility.
            system_message: Content of the system turn prepended to every
                conversation.

        Returns:
            Dict with ``input_ids`` (the ``<image>`` token id replaced by
            ``IMAGE_TOKEN_INDEX``) and ``labels`` (the same ids without that
            replacement), both ``torch.long`` tensors.
        """
        roles = {"human": "user", "gpt": "assistant"}
        image_token_id = self._prepare_tokenizer()
        encode = self.tokenizer.apply_chat_template
        # Identical for every conversation, so encode it once.
        system_ids = encode([{"role": "system", "content": system_message}])

        input_ids, targets = [], []
        for source in sources:
            turns = [self._role_and_content(conv, roles) for conv in source]
            # A conversation must open with the user; drop a leading non-user turn.
            if turns and turns[0][0] != roles["human"]:
                turns = turns[1:]

            token_ids = list(system_ids)
            for role, content in turns:
                encoded = encode([{"role": role, "content": content}])
                if role == roles["human"]:
                    token_ids += encoded
                else:
                    # Drop the last two ids -- the '<|im_end|>' and newline the
                    # template appends, assuming one token each (TODO confirm)
                    # -- so the turn stays open for generation.
                    token_ids += encoded[:-2]

            # Labels keep the real token ids; only the model input gets the
            # image placeholder.
            targets.append(list(token_ids))
            input_ids.append(
                [IMAGE_TOKEN_INDEX if tok == image_token_id else tok for tok in token_ids]
            )

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        targets = torch.tensor(targets, dtype=torch.long)
        return dict(input_ids=input_ids, labels=targets)

    def __call__(
        self,
        prompt: str = "Please reconstruct the given image.",
        image: Union[str, Image.Image] = None,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        **kwargs
    ) -> "PipelineResult":
        """
        Main interface for image generation

        Args:
            prompt: Instruction placed after the ``<image>`` placeholder in the
                user turn. ``None`` falls back to the default instruction.
            image: Input image (path or PIL Image object). A single image is
                expected here; list inputs are not batched by this method.
            num_inference_steps: Number of inference steps (not yet used, kept for compatibility)
            guidance_scale: Guidance scale (not yet used, kept for compatibility)
            **kwargs: Other parameters (currently ignored)

        Returns:
            Result object containing generated images

        Raises:
            ValueError: If ``image`` is missing or of an unsupported type.
        """
        if image is None:
            raise ValueError("image is required for this pipeline")

        # Process image
        processed_image, image_size, original_image = self._process_image(image)
        detailed_condition = self._process_target_image(original_image)

        # Build the user instruction from the caller's prompt, adding the image
        # placeholder unless the prompt already carries one.
        instruction = self._DEFAULT_PROMPT if prompt is None else prompt
        if self._IMAGE_TOKEN not in instruction:
            instruction = f"{self._IMAGE_TOKEN}\n{instruction}"

        # Prepare messages
        messages = [
            {"from": "human", "value": instruction},
            {"from": "gpt", "value": f"<im_start><S{self.config.scale}>"}
        ]

        # Preprocess input
        data_dict = self._preprocess_qwen([messages], has_image=True)
        inputs = data_dict['input_ids']

        # Generate image
        with torch.no_grad():
            output_images = self.model.generate_images_from_image(
                inputs.to(self.device),
                images=[processed_image],
                detailed_conditions=[detailed_condition],
                max_new_tokens=self.config.seq_len,
                do_sample=True,
                top_p=self.config.top_p,
                top_k=self.config.top_k,
            )

        # Return results
        return PipelineResult(
            images=output_images,
            original_image=original_image,
            prompt=prompt
        )


class PipelineResult:
    """Holds what a pipeline call produced.

    Attributes:
        images: Generated images.
        original_image: The input image the generation was conditioned on.
        prompt: The prompt the caller passed.
    """

    def __init__(self, images: List[Image.Image], original_image: Image.Image, prompt: str):
        self.prompt = prompt
        self.original_image = original_image
        self.images = images

    def save(self, output_path: str, index: int = 0):
        """Write the generated image at *index* to *output_path*.

        Raises:
            IndexError: If *index* is not below the number of generated images.
        """
        available = len(self.images)
        if not index < available:
            raise IndexError(f"Index {index} out of range. Only {len(self.images)} images available.")
        chosen = self.images[index]
        chosen.save(output_path)


# Convenience functions
def load_pipeline(model_path: str, **kwargs) -> PrefRestorePipeline:
    """Build a ``PrefRestorePipeline`` for *model_path*.

    Shorthand for ``PrefRestorePipeline.from_pretrained``; any extra keyword
    arguments are forwarded to it unchanged.
    """
    factory = PrefRestorePipeline.from_pretrained
    pipeline = factory(model_path, **kwargs)
    return pipeline