"""Thin torch wrapper around ImageReward-v1.0 for scoring (prompt, image) pairs."""
import os

import torch
from PIL import Image

import ImageReward as RM


class ImageRewardScorer(torch.nn.Module):
    """Score aligned (prompt, image) pairs with the frozen ImageReward-v1.0 model.

    The wrapped model is put in eval mode, cast to ``dtype`` and has gradients
    disabled; this module is intended for inference only.
    """

    def __init__(self, device="cuda", dtype=torch.float32):
        """Load ImageReward-v1.0 and freeze it.

        Args:
            device: Device passed to ``RM.load``; the reward tensor returned by
                ``forward`` is also placed here.
            dtype: Dtype the model parameters are cast to.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype
        # os.path.join does not expand "~"; without expanduser the default
        # would create a literal "./~/.cache/" directory under the CWD.
        cache_root = os.path.expanduser(os.environ.get("HF_HOME", "~/.cache/"))
        self.model = (
            RM.load(
                "ImageReward-v1.0",
                device=device,
                download_root=os.path.join(cache_root, "ImageReward"),
            )
            .eval()
            .to(dtype=dtype)
        )
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, prompts, images):
        """Return one reward per (prompt, image) pair.

        Invoked through ``scorer(prompts, images)`` as with any ``nn.Module``.

        Args:
            prompts: Sequence of text prompts.
            images: Sequence of images aligned with ``prompts``; each element is
                handed unchanged to ``self.model.score`` (PIL images in the
                example below).

        Returns:
            1-D float32 tensor on ``self.device`` whose element ``i`` is the
            reward of ``images[i]`` under ``prompts[i]``. Empty input yields an
            empty tensor.

        Raises:
            ValueError: If ``prompts`` and ``images`` differ in length.
        """
        if len(prompts) != len(images):
            raise ValueError(
                f"expected one image per prompt, got {len(prompts)} prompts "
                f"and {len(images)} images"
            )
        # Score each pair on its own rather than ranking every prompt against
        # every image and keeping the diagonal: that costs N*N text-encoder
        # work for N pairs, and for a single pair inference_rank returns a
        # bare float that torch.Tensor() cannot wrap.
        rewards = [self.model.score(prompt, image) for prompt, image in zip(prompts, images)]
        return torch.tensor(rewards, dtype=torch.float32, device=self.device)


# Usage example
def main():
    """Score the two sample images in ``test_cases/`` against their prompts."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = ImageRewardScorer(device=device, dtype=torch.float32)
    image_paths = [
        "test_cases/nasa.jpg",
        "test_cases/hello world.jpg",
    ]
    # convert() reads the pixel data, so the file handle can be closed by the
    # context manager instead of staying open until garbage collection.
    pil_images = []
    for path in image_paths:
        with Image.open(path) as img:
            pil_images.append(img.convert("RGB"))
    prompts = [
        'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
        'New York Skyline with "Hello World" written with fireworks on the sky',
    ]
    print(scorer(prompts, pil_images))


if __name__ == "__main__":
    main()