This codebase was built using Python 3.7 (CUDA 11.3), PyTorch 1.12.0, and Torchvision 0.13.0. Use the following script to create a virtual environment and install the required packages.
python3.7 -m venv $HOME/envs/mvq
source $HOME/envs/mvq/bin/activate
pip install -U pip
pip install torch==1.12.0 torchvision==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt
In /config/default.yaml, ensure that
All other configurations inherit properties from default.yaml.
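This inheritance can be pictured as a recursive dictionary merge, where keys set in a child config override those in default.yaml. The following is a minimal sketch with hypothetical keys; the actual loader in this codebase may differ:

```python
def merge_configs(default, override):
    """Recursively merge an override config into a default config.

    Values in `override` win; nested dicts are merged key by key.
    """
    merged = dict(default)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value
    return merged

# Hypothetical keys for illustration only.
default_cfg = {"data": {"path": {"home": "/data"}}, "model": {"backbone": "vq"}}
child_cfg = {"model": {"backbone": "vq2"}}
cfg = merge_configs(default_cfg, child_cfg)
```

Here the child config only overrides model.backbone, while data.path.home is inherited from the default.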
Not all datasets are downloaded automatically by PyTorch. Links to manually download the remaining datasets are provided below.
Store each of these in the location specified under data.path.home in default.yaml.
See /dataset/build.py for the LMDB constructor used for the ImageNet datasets.
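As a quick sanity check after downloading, you can confirm the expected dataset folders exist under the configured home path. This is only a sketch; the folder names here are assumptions and have not been verified against the loaders:

```python
from pathlib import Path

def missing_datasets(data_home, names):
    """Return the dataset names that have no folder under data_home."""
    home = Path(data_home)
    return [name for name in names if not (home / name).is_dir()]

# Example: pass the value configured under data.path.home in default.yaml.
missing = missing_datasets("/data", ["celeba64", "imagenet32", "imagenet64"])
```

An empty result means every listed folder is present.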
Precomputing FID weights is executed automatically in main.py. However, you may also run it manually with:
python precompute_fid.py --dataset <DATASET> --data_dir <DATA_DIR>
The entry point is main.py, which requires specifying mode=['train','eval'], stage=['vae','seq'], and dataset=['mnist', 'kmnist', 'fashion', 'cifar10', 'celeba64', 'imagenet32', 'imagenet64'].
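A sketch of how these choices might be declared with argparse is shown below. The short/long flag names follow the example commands later in this document, but the defaults and help strings are assumptions, not the actual main.py:

```python
import argparse

parser = argparse.ArgumentParser(description="MVQ entry point (illustrative sketch)")
parser.add_argument("-c", "--config", required=True,
                    help="config YAML, e.g. 32-MVQGAN.yaml")
parser.add_argument("-m", "--mode", choices=["train", "eval"], required=True)
parser.add_argument("-s", "--stage", choices=["vae", "seq"], required=True)
parser.add_argument("-d", "--dataset", required=True,
                    choices=["mnist", "kmnist", "fashion", "cifar10",
                             "celeba64", "imagenet32", "imagenet64"])
parser.add_argument("--dist", type=int, default=1,
                    help="set 0 to disable PyTorch Distributed")

# Parse a sample command line rather than sys.argv, for demonstration.
args = parser.parse_args(["-c", "32-MVQGAN.yaml", "-m", "train",
                          "-s", "vae", "-d", "cifar10"])
```

With `choices`, argparse rejects any mode, stage, or dataset outside the allowed lists.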
The codebase is built on PyTorch Distributed. The following template is suitable for multiple GPUs; modify it according to your environment.
torchrun --nnodes=<NNODES> \
--nproc_per_node=<TASKS_PER_NODE> \
--max_restarts=3 \
--rdzv_id=<ID> \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR \
main.py --config <CONFIG> --mode <MODE> --stage <STAGE> --dataset <DATASET>
The following command works for a single GPU. Setting the "--dist" flag to 0 disables PyTorch Distributed.
python main.py --config <CONFIG> --mode <MODE> --stage <STAGE> --dataset <DATASET> --dist 0
See the official PyTorch Distributed guide for more information.
See /bin/train_cifar10.sh for an example.
Flags provided to main.py override the config. A few of these are shown below:
torchrun --nnodes=<NNODES> \
--nproc_per_node=<TASKS_PER_NODE> \
--max_restarts=3 \
--rdzv_id=<ID> \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR \
main.py -c 32-MVQGAN.yaml -m train -s vae -d cifar10
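Conceptually, the override works by letting any flag the user actually passes replace the corresponding config value, while unset flags fall back to the config. A minimal sketch (the real merge logic in main.py may differ):

```python
def apply_overrides(config, cli_args):
    """Override config entries with CLI values that were explicitly given.

    `cli_args` maps option names to values; None means 'flag not passed'.
    """
    result = dict(config)
    for key, value in cli_args.items():
        if value is not None:
            result[key] = value
    return result

# Hypothetical values for illustration.
config = {"mode": "eval", "stage": "vae", "dataset": "cifar10"}
cli = {"mode": "train", "stage": None, "dataset": None}
final = apply_overrides(config, cli)  # mode comes from the CLI, the rest from config
```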
MH-Dropout (described in the paper) is implemented in /models/mh_dropout/mhd_random_2d.py. It can be enabled in the config files by setting "model.mhd.use" to "true".
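As a rough, hypothetical illustration of the multi-head idea only (not the paper's exact formulation; see mhd_random_2d.py for the real implementation), one can pre-sample several dropout masks ("heads") and apply one chosen at random per forward pass:

```python
import random

def make_heads(num_heads, size, keep_prob, rng):
    """Pre-sample `num_heads` binary dropout masks of length `size`."""
    return [[1 if rng.random() < keep_prob else 0 for _ in range(size)]
            for _ in range(num_heads)]

def mh_dropout(x, heads, rng):
    """Apply one randomly chosen head's mask to the activations `x`."""
    mask = rng.choice(heads)
    return [xi * mi for xi, mi in zip(x, mask)]

rng = random.Random(0)
heads = make_heads(num_heads=4, size=8, keep_prob=0.5, rng=rng)
out = mh_dropout([1.0] * 8, heads, rng)
```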
The VQVAE2 backbone is enabled by setting "model.backbone" to "vq2". MH-Dropout is not compatible with this backbone and will be ignored if both are set.
The discriminator can be enabled in the config via "model.gan.use". The VQGAN encoder/decoder is also integrated and can be selected by setting "model.coder.name" to "conv_unet".