import os
import json
import argparse
from PIL import Image
from tqdm import tqdm
# === Model and tokenizer imports ===
from dataclasses import dataclass
import torch
from transformers import AutoTokenizer
from blip3o.model import *  # provides blip3oQwenForInferenceLMVAE
from blip3o.constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
IGNORE_INDEX,
IMAGE_TOKEN_INDEX,
)
from blip3o.data.image_degradation import degrade_image
from torchvision.transforms import v2
# === Degradation parameters (for degrade_image; degradation is currently
# bypassed in generate_image below) ===
degradation_params = {
    'gt_size': 512,               # ground-truth crop size
    'in_size': 512,               # degraded input size
    'use_motion_kernel': False,   # disable motion-blur kernels
    'blur_kernel_size': 41,       # Gaussian blur kernel size
    'blur_sigma': [1, 15],        # blur sigma sampled from this range
    'downsample_range': [4, 30],  # downsampling factor range
    'noise_range': [0, 20],       # additive noise strength range
    'jpeg_range': [30, 80],       # JPEG quality factor range
}
# Target transform for SANA: resize + center-crop to 512, normalize to [-1, 1]
target_transform = v2.Compose(
[
v2.Resize(512),
v2.CenterCrop(512),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize([0.5], [0.5]),
]
)
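# Minimal sanity-check sketch (not called anywhere; safe to delete). It
# illustrates the contract of target_transform: a PIL RGB image of any size
# comes out as a float32 tensor of shape (3, 512, 512) scaled to [-1, 1].
# The single-element mean/std in v2.Normalize broadcasts over all channels.
def _check_target_transform() -> None:
    demo = Image.new("RGB", (768, 640), color=(128, 128, 128))
    tensor = target_transform(demo)
    assert tensor.shape == (3, 512, 512)
    assert tensor.min() >= -1.0 and tensor.max() <= 1.0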
@dataclass
class T2IConfig:
model_path: str = "/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restore_restoration-FFHQ+CelebA/checkpoint-30800"
device: str = "cuda:0"
dtype: torch.dtype = torch.bfloat16
    # Generation config
    scale: int = 0       # index of the <S{scale}> token in the assistant prefix
    seq_len: int = 729   # number of image tokens to generate
    top_p: float = 0.95  # nucleus-sampling threshold (used only when sampling)
    top_k: int = 1200    # top-k filter (used only when sampling)
class TextToImageInference:
def __init__(self, config: T2IConfig):
self.config = config
self.device = torch.device(config.device)
self._load_models()
        # Image processor exposed by the model's vision tower.
        self.processor = self.model.get_vision_tower().image_processor
def _load_models(self):
self.model = blip3oQwenForInferenceLMVAE.from_pretrained(
self.config.model_path, torch_dtype=self.config.dtype
).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
    def process_image(self, image):
        """Preprocess with the vision tower's processor; return (pixel tensor, original size)."""
        image_size = image.size
        image = self.processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        return image, image_size
    def preprocess_qwen(self, sources, tokenizer, has_image: bool = True, max_len=2048,
                        system_message: str = "You are a helpful assistant."):
        """Tokenize ChatML-style conversations for Qwen.

        `has_image` and `max_len` are kept for interface compatibility but are
        unused here.
        """
        roles = {"human": "user", "gpt": "assistant"}
        # Register the <image> placeholder once per process and cache its id
        # in a module-level global.
        if 'image_token_index' not in globals():
            tokenizer.add_tokens(["<image>"], special_tokens=True)
            global image_token_index
            image_token_index = tokenizer.convert_tokens_to_ids("<image>")
        im_start, im_end = tokenizer.additional_special_tokens_ids[:2]
        # Label ids never masked: 198 (newline in the Qwen vocab), <|im_start|>, <|im_end|>.
        unmask_tokens_idx = [198, im_start, im_end]
chat_template = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)
tokenizer.chat_template = chat_template
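        # For example, applying this template to a single user message "hi"
        # with add_generation_prompt=True renders:
        #     <|im_start|>user
        #     hi<|im_end|>
        #     <|im_start|>assistant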
input_ids, targets = [], []
        for source in sources:
            # Ensure the conversation starts with a human turn.
            if roles[source[0]["from"]] != roles["human"]:
                source = source[1:]
input_id, target = [], []
input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
target += input_id
            for conv in source:
                # Accept both {"role", "content"} and {"from", "value"} message keys.
                try:
                    role = conv["role"]
                    content = conv["content"]
                except KeyError:
                    role = conv["from"]
                    content = conv["value"]
role = roles.get(role, role)
conv = [{"role": role, "content": content}]
encode_id = tokenizer.apply_chat_template(conv)
if role == roles["human"]:
input_id += encode_id
target += encode_id
                else:
                    # Drop the trailing <|im_end|> and newline from the assistant
                    # turn so generation continues from the assistant prefix.
                    input_id += encode_id[:-2]
                    target += encode_id[:-2]
assert len(input_id) == len(target)
            for idx, encode_id in enumerate(input_id):
                # Structural tokens (newline, <|im_start|>, <|im_end|>) keep their labels.
                if encode_id in unmask_tokens_idx:
                    target[idx] = encode_id
                # Remap the tokenizer's <image> id to the model's IMAGE_TOKEN_INDEX.
                if encode_id == image_token_index:
                    input_id[idx] = IMAGE_TOKEN_INDEX
input_ids.append(input_id)
targets.append(target)
        # torch.tensor() requires equal-length rows; this script always passes a
        # single conversation, so stacking is safe.
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        targets = torch.tensor(targets, dtype=torch.long)
return dict(input_ids=input_ids, labels=targets)
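    # Sketch of the expected conversation format and return value
    # (illustrative; mirrors the call in generate_image below):
    #     messages = [
    #         {"from": "human", "value": "<image>\nPlease reconstruct the given image."},
    #         {"from": "gpt", "value": "<im_start><S0>"},
    #     ]
    #     data = self.preprocess_qwen([messages], self.tokenizer)
    #     data["input_ids"]  # LongTensor, shape (1, sequence_length)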
    def process_target_image(self, image):
        """Map a PIL image to the normalized 512x512 tensor used as the detailed condition."""
        image = target_transform(image)
        return image
    def generate_image(self, prompt: str, image_file: str) -> tuple[Image.Image, Image.Image]:
        image = Image.open(image_file).convert("RGB")
        # Degradation is bypassed: the input image is assumed to be pre-degraded.
        degraded_image = image
        image, _ = self.process_image(degraded_image)
        detailed_condition = self.process_target_image(degraded_image)
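        # To synthesize degradations on the fly instead, something like
        # `degraded_image = degrade_image(image, degradation_params)` could be
        # used; that call signature is an assumption and should be checked
        # against blip3o.data.image_degradation before enabling it.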
# messages = [
# {"from": "human", "value": "<image>\nPlease reconstruct the given image."},
# {"from": "gpt", "value": f"<im_start><S{self.config.scale}>"}
# ]
messages = [
{"from": "human", "value": f"<image>\nPlease reconstruct the given image based on the image content: {prompt}"},
{"from": "gpt", "value": f"<im_start><S{self.config.scale}>"}
]
data_dict = self.preprocess_qwen([messages], self.tokenizer, has_image=True)
inputs = data_dict['input_ids']
output_image = self.model.generate_images_from_image(
inputs.to(self.device),
images=[image],
detailed_conditions=[detailed_condition],
max_new_tokens=self.config.seq_len,
            # Greedy decoding; set do_sample=True to make top_p/top_k take effect.
            do_sample=False,
            top_p=self.config.top_p,
            top_k=self.config.top_k,
)
return degraded_image, output_image[0]
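# Minimal usage sketch (paths are illustrative placeholders, not real files):
#     config = T2IConfig(model_path="/path/to/checkpoint")
#     inference = TextToImageInference(config)
#     lq_image, restored = inference.generate_image("a portrait photo", "face.png")
#     restored.save("restored.png")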
def main():
# Add command line argument parsing
parser = argparse.ArgumentParser(description="Batch image generation with degradation")
parser.add_argument("--model_path", type=str,
default='/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Face-Restore_restoration/checkpoint-34640',
help="Path to the model checkpoint")
parser.add_argument("--json_path", type=str,
default="/data/zgq/yaozhengjian/Datasets/FFHQ_val/CelebA_HQ/captions.json",
help="Path to the JSON dataset file")
parser.add_argument("--output_dir", type=str,
default="/data/phd/yaozhengjian/zjYao_Exprs/BLIP-3o-next/Eval/FR-FFHQ-heavy",
help="Output directory for generated images")
args = parser.parse_args()
config = T2IConfig()
config.model_path = args.model_path
inference = TextToImageInference(config)
# === Read JSON file ===
with open(args.json_path, "r") as f:
dataset = json.load(f)
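    # Expected JSON schema (inferred from the keys read below):
    #     [{"image": "/path/to/input.png", "caption": "text prompt"}, ...]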
os.makedirs(args.output_dir, exist_ok=True)
# tqdm progress bar
    for sample in tqdm(dataset, desc="Generating images"):
image_file = sample["image"]
prompt = sample["caption"]
try:
degraded_image, image_sana = inference.generate_image(prompt, image_file)
            base_name = os.path.splitext(os.path.basename(image_file))[0]
# Save restored image
save_path = os.path.join(args.output_dir, "restored", f"{base_name}.png")
os.makedirs(os.path.dirname(save_path), exist_ok=True)
image_sana.save(save_path)
# Save degraded image
# degraded_save_path = os.path.join(args.output_dir, "degraded", f"{base_name}.png")
# os.makedirs(os.path.dirname(degraded_save_path), exist_ok=True)
# degraded_image.save(degraded_save_path)
# Print save path
tqdm.write(f"Saved: {save_path}") # Does not disrupt progress bar
# tqdm.write(f"Saved degraded: {degraded_save_path}")
except Exception as e:
tqdm.write(f"Error processing {image_file}: {e}")
if __name__ == "__main__":
main()