# Pref-Restoration / blip3o / data / image_degradation.py
# Synthetic image degradations: blur, downsampling, Gaussian noise, and JPEG compression.
import cv2
import math
import random
import numpy as np
import torch
from PIL import Image
import io
from . import gaussian_kernels


def degrade_image(img_gt,
                  gt_size=512,
                  in_size=512,
                  use_motion_kernel=False,
                  motion_kernels=None,
                  motion_kernel_prob=0.001,
                  kernel_list=('iso', 'aniso'),
                  kernel_prob=(0.5, 0.5),
                  blur_kernel_size=41,
                  blur_sigma=(1, 15),
                  downsample_range=(4, 30),
                  noise_range=(0, 20),
                  jpeg_range=(30, 80)):
    """Apply a random blur -> downsample -> noise -> JPEG -> resize degradation.

    Args:
        img_gt: Source image as a PIL image or anything ``np.array`` accepts.
            Pixel values are assumed to be 8-bit (0-255). HxWx3 input is
            treated as RGB and HxWx4 input as RGBA.
        gt_size: Nominal side length of the clean image. The downsampled side
            is ``gt_size // scale`` regardless of the actual input size.
        in_size: Side length of the returned (square) degraded image.
        use_motion_kernel: Whether a motion-blur kernel may be applied first.
        motion_kernels: Mapping with keys ``'00'``..``'31'`` whose values are
            kernels accepted by ``cv2.filter2D``; needed for motion blur.
        motion_kernel_prob: Probability of applying motion blur when enabled.
        kernel_list, kernel_prob, blur_kernel_size, blur_sigma: Forwarded to
            ``gaussian_kernels.random_mixed_kernels``; ``blur_sigma`` is used
            as both the x and the y sigma range.
        downsample_range: ``(low, high)`` range of the uniform downscale factor.
        noise_range: ``(low, high)`` range, in 0-255 units, of the Gaussian
            noise sigma. ``None`` disables the noise stage.
        jpeg_range: ``(low, high)`` JPEG quality range. ``None`` disables the
            JPEG stage.

    Returns:
        A ``PIL.Image.Image`` of size ``in_size`` x ``in_size``.

    Raises:
        RuntimeError: If OpenCV fails to JPEG-encode the intermediate image.
    """
    img = np.array(img_gt)
    # OpenCV works in BGR(A) channel order; convert colour inputs up front and
    # undo it at the end. RGBA must be handled too, otherwise the JPEG encoder
    # (which reads 4-channel data as BGRA) swaps red and blue.
    n_channels = img.shape[2] if img.ndim == 3 else 1
    if n_channels == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    elif n_channels == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
    img_in = img.astype(np.float32) / 255.0

    # Optional motion blur with one of 32 pre-computed kernels.
    if use_motion_kernel and motion_kernels is not None and random.random() < motion_kernel_prob:
        m_i = random.randint(0, 31)
        k = motion_kernels[f'{m_i:02d}']
        img_in = cv2.filter2D(img_in, -1, k)

    # Random isotropic/anisotropic Gaussian blur.
    kernel = gaussian_kernels.random_mixed_kernels(
        kernel_list,
        kernel_prob,
        blur_kernel_size,
        blur_sigma,
        blur_sigma,
        [-math.pi, math.pi],
        noise_range=None)
    img_in = cv2.filter2D(img_in, -1, kernel)

    # Random downsampling. Clamp to 1 px so cv2.resize never receives a
    # zero-sized target when scale exceeds gt_size.
    scale = np.random.uniform(downsample_range[0], downsample_range[1])
    lr_size = max(1, int(gt_size // scale))
    img_in = cv2.resize(img_in, (lr_size, lr_size), interpolation=cv2.INTER_LINEAR)

    # Additive Gaussian noise; sigma is drawn in 0-255 units and rescaled.
    if noise_range is not None:
        noise_sigma = np.random.uniform(noise_range[0] / 255., noise_range[1] / 255.)
        noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
        img_in = img_in + noise
        img_in = np.clip(img_in, 0, 1)

    # JPEG round-trip at a random quality.
    if jpeg_range is not None:
        jpeg_p = np.random.uniform(jpeg_range[0], jpeg_range[1])
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
        ok, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
        if not ok:
            raise RuntimeError('cv2.imencode failed to JPEG-encode the degraded image')
        # IMREAD_COLOR (== 1) always decodes to 3-channel BGR.
        img_in = np.float32(cv2.imdecode(encimg, cv2.IMREAD_COLOR)) / 255.

    # Resize to the requested network input size.
    img_in = cv2.resize(img_in, (in_size, in_size), interpolation=cv2.INTER_LINEAR)

    img_out = np.clip(img_in * 255.0, 0, 255).astype(np.uint8)
    if img_out.ndim == 3 and img_out.shape[2] == 3:
        img_out = cv2.cvtColor(img_out, cv2.COLOR_BGR2RGB)
    elif img_out.ndim == 3 and img_out.shape[2] == 4:
        img_out = cv2.cvtColor(img_out, cv2.COLOR_BGRA2RGBA)

    return Image.fromarray(img_out)


def load_motion_kernels(motion_kernel_path='basicsr/data/motion-blur-kernels-32.pth'):
    """Load the motion-blur kernel bank used by ``degrade_image``.

    Args:
        motion_kernel_path: Path to a file written with ``torch.save``.

    Returns:
        The object stored at that path; presumably a mapping from ``'00'``..
        ``'31'`` to kernels usable by ``cv2.filter2D`` (TODO confirm against
        the shipped file).
    """
    # The kernels are consumed by OpenCV on the CPU. Loading onto the CPU also
    # keeps a file saved from CUDA tensors loadable on CPU-only hosts.
    # NOTE(review): torch.load unpickles data; only point this at trusted files.
    return torch.load(motion_kernel_path, map_location='cpu')


class ImageDegradationConfig:
    """Bundle of degradation hyper-parameters applied through ``degrade_image``.

    Sequence-valued settings are copied into fresh per-instance lists, so
    mutating one config never leaks into another instance or into the
    defaults.
    """

    def __init__(self,
                 gt_size=512,
                 in_size=512,
                 use_motion_kernel=False,
                 motion_kernel_path='basicsr/data/motion-blur-kernels-32.pth',
                 motion_kernel_prob=0.001,
                 kernel_list=('iso', 'aniso'),
                 kernel_prob=(0.5, 0.5),
                 blur_kernel_size=41,
                 blur_sigma=(1, 15),
                 downsample_range=(4, 30),
                 noise_range=(0, 20),
                 jpeg_range=(30, 80),
                 motion_kernels=None):
        """Store the settings and, if needed, load the motion-blur kernels.

        Args:
            gt_size, in_size, use_motion_kernel, motion_kernel_prob,
            kernel_list, kernel_prob, blur_kernel_size, blur_sigma,
            downsample_range, noise_range, jpeg_range: Same meaning as the
                corresponding ``degrade_image`` arguments. ``noise_range`` or
                ``jpeg_range`` set to ``None`` disables that stage.
            motion_kernel_path: File passed to ``load_motion_kernels`` when
                motion blur is enabled and no kernels are supplied.
            motion_kernels: Optional pre-loaded kernel mapping. When given,
                it is used as-is and ``motion_kernel_path`` is not read.
        """
        self.gt_size = gt_size
        self.in_size = in_size
        self.use_motion_kernel = use_motion_kernel
        self.motion_kernel_prob = motion_kernel_prob
        # Copy sequences so instances never share (or mutate) default objects.
        self.kernel_list = list(kernel_list)
        self.kernel_prob = list(kernel_prob)
        self.blur_kernel_size = blur_kernel_size
        self.blur_sigma = list(blur_sigma)
        self.downsample_range = list(downsample_range)
        # None is meaningful here: it switches the stage off in degrade_image.
        self.noise_range = None if noise_range is None else list(noise_range)
        self.jpeg_range = None if jpeg_range is None else list(jpeg_range)

        self.motion_kernels = motion_kernels
        if self.use_motion_kernel and self.motion_kernels is None:
            self.motion_kernels = load_motion_kernels(motion_kernel_path)

    def degrade(self, img_gt):
        """Return ``degrade_image(img_gt, ...)`` using this config's settings."""
        return degrade_image(
            img_gt=img_gt,
            gt_size=self.gt_size,
            in_size=self.in_size,
            use_motion_kernel=self.use_motion_kernel,
            motion_kernels=self.motion_kernels,
            motion_kernel_prob=self.motion_kernel_prob,
            kernel_list=self.kernel_list,
            kernel_prob=self.kernel_prob,
            blur_kernel_size=self.blur_kernel_size,
            blur_sigma=self.blur_sigma,
            downsample_range=self.downsample_range,
            noise_range=self.noise_range,
            jpeg_range=self.jpeg_range
        )



if __name__ == "__main__":
    # Smoke test. Run as a module (python -m ...image_degradation) so the
    # relative `gaussian_kernels` import resolves.
    #   argv[1] (optional): input image path; defaults to a synthetic gradient.
    #   argv[2] (optional): path to save the degraded result to.
    import sys

    config = ImageDegradationConfig(
        gt_size=512,
        in_size=512,
        use_motion_kernel=False,
        kernel_list=['iso', 'aniso'],
        kernel_prob=[0.5, 0.5],
        blur_kernel_size=41,
        blur_sigma=[1, 15],
        downsample_range=[4, 30],
        noise_range=[0, 20],
        jpeg_range=[30, 80],
    )

    if len(sys.argv) > 1:
        img_gt = Image.open(sys.argv[1]).convert('RGB')
    else:
        ramp = np.linspace(0, 255, config.gt_size).astype(np.uint8)
        gray = np.tile(ramp, (config.gt_size, 1))
        img_gt = Image.fromarray(np.stack([gray, gray[:, ::-1], gray.T], axis=-1))

    img_degraded = config.degrade(img_gt)
    print(f'degraded image: size={img_degraded.size}, mode={img_degraded.mode}')
    if len(sys.argv) > 2:
        img_degraded.save(sys.argv[2])