
PyTorch Implementation of "Multiple Hypothesis Dropout"

Requirements

This codebase was built using Python 3.7 (CUDA 11.3), PyTorch 1.12.0, Torchvision 0.13.0. Use the following script to build a virtual env and install required packages.

python3.7 -m venv $HOME/envs/mvq
source $HOME/envs/mvq/bin/activate
pip install -U pip

pip install torch==1.12.0 torchvision==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt

Setup

Default Configuration

In /config/default.yaml, ensure that

  • exp_dir.home -- points to your preferred experiment directory
  • data.path.home -- points to your preferred dataset directory

All other configurations inherit properties from default.yaml.
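
The inheritance behaves like a recursive dictionary overlay. The sketch below is illustrative only (the repo's actual loader and key names may differ); a child config starts from default.yaml and overrides individual keys:

```python
# Illustrative sketch of config inheritance (not the repo's actual loader):
# child configs start from default.yaml and override individual keys.

def merge(default, override):
    """Recursively overlay `override` onto `default`."""
    out = dict(default)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out

# Hypothetical values standing in for the real YAML contents.
default_cfg = {
    "exp_dir": {"home": "/path/to/experiments"},
    "data": {"path": {"home": "/path/to/datasets"}},
    "model": {"backbone": "vq"},
}
child_cfg = {"model": {"backbone": "vq2"}}  # e.g. a VQVAE-2 config

cfg = merge(default_cfg, child_cfg)
```

Keys not mentioned in the child config (such as exp_dir.home above) keep their default.yaml values.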

Datasets

Not all datasets are downloaded automatically by PyTorch. Links for manually downloading datasets are provided below.

Store each dataset in the location specified by data.path.home in default.yaml.

See /dataset/build.py for the LMDB constructor used for the ImageNet datasets.

Precompute FID

Precomputing FID statistics should be executed automatically by main.py. However, you may run it manually with:

python precompute_fid.py --dataset <DATASET> --data_dir <DATA_DIR>

Training and evaluation

The entry point is main.py, which requires mode=['train','eval'], stage=['vae','seq'] and dataset=['mnist', 'kmnist', 'fashion', 'cifar10', 'celeba64', 'imagenet32', 'imagenet64'].

  • stage=="vae" trains the codebook, encoder and decoder.
  • stage=="seq" trains the sequential probabilistic model over the tokens (pixel or transformer).
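
For example, a full run executes the two stages in order, then evaluates. The commands below use the single-GPU form with --dist 0 (described further down); the config name is taken from the example later in this README and may differ in your setup:

```shell
# Hypothetical end-to-end run on CIFAR-10 (single GPU, PyTorch Distributed disabled).
# Stage 1: learn the codebook, encoder and decoder.
python main.py --config 32-MVQGAN.yaml --mode train --stage vae --dataset cifar10 --dist 0
# Stage 2: learn the sequential model over the discrete tokens.
python main.py --config 32-MVQGAN.yaml --mode train --stage seq --dataset cifar10 --dist 0
# Evaluate the trained model.
python main.py --config 32-MVQGAN.yaml --mode eval --stage seq --dataset cifar10 --dist 0
```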

The codebase is built on PyTorch Distributed. The following is a suitable template for multiple GPUs; modify it according to your environment.

torchrun --nnodes=<NNODES> \
    --nproc_per_node=<TASKS_PER_NODE> \
    --max_restarts=3 \
    --rdzv_id=<ID> \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR \
    main.py --config <CONFIG> --mode <MODE> --stage <STAGE> --dataset <DATASET> 

The following command works for a single GPU; setting "--dist 0" disables PyTorch Distributed.

python main.py --config <CONFIG> --mode <MODE> --stage <STAGE> --dataset <DATASET> --dist 0

See the official PyTorch Distributed guide for more information.

See /bin/train_cifar10.sh for an example.

Additional Flags

Flags provided to main.py override values from the config file. A few are shown below:
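
The override behaves like the sketch below (illustrative only; main.py's actual argument handling may differ): flags that are explicitly provided on the command line take precedence over values loaded from YAML.

```python
# Illustrative sketch of CLI flags overriding config values
# (hypothetical; not main.py's actual argument handling).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default="default.yaml")
parser.add_argument("-m", "--mode", choices=["train", "eval"])
parser.add_argument("-s", "--stage", choices=["vae", "seq"])
parser.add_argument("-d", "--dataset")

# Equivalent to: main.py -c 32-MVQGAN.yaml -m train -s vae -d cifar10
args = parser.parse_args(
    ["-c", "32-MVQGAN.yaml", "-m", "train", "-s", "vae", "-d", "cifar10"]
)

config = {"mode": "eval", "stage": "seq", "dataset": "mnist"}  # values from YAML
# Flags that were actually provided take precedence over the file:
config.update({k: v for k, v in vars(args).items() if v is not None})
```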

torchrun --nnodes=<NNODES> \
    --nproc_per_node=<TASKS_PER_NODE> \
    --max_restarts=3 \
    --rdzv_id=<ID> \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR \
    main.py -c 32-MVQGAN.yaml -m train -s vae -d cifar10 

Other Configuration Settings

MH-Dropout

MH-Dropout (described in the paper) is implemented in /models/mh_dropout/mhd_random_2d.py. This can be activated in the config files by setting "model.mhd.use" to "true".

VQVAE-2

The VQVAE2 backbone is activated with "model.backbone"=="vq2". MH-Dropout is not compatible with this option and will therefore be ignored.

VQGAN

The discriminator can be activated in config with "model.gan.use". VQGAN encoder/decoder is also integrated and can be used by setting "model.coder.name"=="conv_unet".
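
Putting these settings together, a hypothetical config fragment (key names as described above; everything else inherited from default.yaml) might look like:

```yaml
# Hypothetical override fragment; all other values inherit from default.yaml.
model:
  backbone: vq          # "vq2" selects the VQVAE-2 backbone (MH-Dropout is then ignored)
  mhd:
    use: true           # enable MH-Dropout (/models/mh_dropout/mhd_random_2d.py)
  gan:
    use: true           # enable the VQGAN discriminator
  coder:
    name: conv_unet     # use the VQGAN encoder/decoder
```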