# mvq/thirdparty/taming/misc/coord.py
import torch

class CoordStage:
    """Parameter-free stage that quantizes a single-channel map in [0, 1].

    Exposes ``eval``/``encode``/``decode`` so it can be used where a VQ model
    is expected (the original author calls this a "fake vqmodel interface").
    Nothing is learned: encoding is area downsampling followed by uniform
    rounding, and decoding is rescaling followed by nearest upsampling.

    Args:
        n_embed: Scale applied before rounding. Because inputs lie in
            ``[0, 1]`` the rounded codes lie in ``0..n_embed`` inclusive.
        down_factor: Spatial reduction applied by ``encode`` and undone by
            ``decode``.
    """

    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        """Return ``self`` unchanged.

        NOTE(review): presumably mirrors ``torch.nn.Module.eval()`` so callers
        can treat this object like a module; confirm against callers.
        """
        return self

    def encode(self, c):
        """Downsample and quantize ``c`` (fake vqmodel interface).

        Args:
            c: Tensor of shape ``(b, 1, h, w)`` with values in ``[0, 1]``.

        Returns:
            A 3-tuple ``(c_quant, None, (None, None, c_ind))`` where
            ``c_quant`` holds the rounded codes in the input's floating dtype
            and ``c_ind`` holds the same values cast to ``torch.long``.

        Raises:
            AssertionError: If values fall outside ``[0, 1]`` or the channel
                dimension is not 1.
        """
        assert 0.0 <= c.min() and c.max() <= 1.0
        b, ch, h, w = c.shape
        assert ch == 1

        # Give interpolate an explicit target size instead of
        # scale_factor=1/down_factor. The reciprocal is rounded in floating
        # point, and the size derived from it can fall one short of the true
        # quotient (e.g. 49 * (1 / 49) < 1 floors to 0). Floor division is
        # exact and matches the scale_factor result wherever that was correct.
        out_size = (int(h // self.down_factor), int(w // self.down_factor))
        c = torch.nn.functional.interpolate(c, size=out_size, mode="area")
        c = c.clamp(0.0, 1.0)
        c = self.n_embed * c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)

        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c):
        """Map codes back to ``[0, 1]`` and upsample to the input resolution.

        Args:
            c: Tensor of codes as produced by ``encode`` (``c_quant``).

        Returns:
            ``c / n_embed`` upsampled by ``down_factor`` with nearest-neighbour
            interpolation.
        """
        c = c / self.n_embed
        c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
                                            mode="nearest")
        return c