# RoboSkate-RL/scripts/python/RoboSkate/VAE-Files/train.py
"""
Train a VAE model using saved images in a folder
"""
import argparse
import os

import cv2
import numpy as np
from stable_baselines.common import set_global_seeds
from tqdm import tqdm
import tensorflow as tf

from controller import VAEController
from data_loader import DataLoader
# from torch.utils.data import DataLoader
from model import ConvVAE


def train_vae():
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--folder', help='Path to a folder containing images for training', type=str,
                        default='C:/Users/meric/Desktop/TUM/CBMLR/Repo/G1_RoboSkate/scripts/python/RoboSkateIL/RoboSkateExpertData/directionError/images')
    parser.add_argument('--z-size', help='Latent space dimension', type=int, default=256)
    parser.add_argument('--seed', help='Random generator seed', type=int, default=0)
    parser.add_argument('--n-samples', help='Max number of samples', type=int, default=-1)
    parser.add_argument('--batch-size', help='Batch size', type=int, default=64)
    parser.add_argument('--learning-rate', help='Learning rate', type=float, default=1e-4)
    parser.add_argument('--kl-tolerance', help='KL tolerance (to cap KL loss)', type=float, default=0.5)
    parser.add_argument('--beta', help='Weight for kl loss', type=float, default=1.0)
    parser.add_argument('--n-epochs', help='Number of epochs', type=int, default=50)
    parser.add_argument('--verbose', help='Verbosity', type=int, default=1)
    args = parser.parse_args()

    set_global_seeds(args.seed)

    if not args.folder.endswith('/'):
        args.folder += '/'

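    # Build the training VAE. beta weights the KL term against the reconstruction loss;
    # kl_tolerance is assumed (as in World Models-style ConvVAEs) to act as a "free bits"
    # floor so the KL term is not pushed below kl_tolerance * z_size.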
    vae = ConvVAE(z_size=args.z_size,
                  batch_size=args.batch_size,
                  learning_rate=args.learning_rate,
                  kl_tolerance=args.kl_tolerance,
                  beta=args.beta,
                  is_training=True,
                  reuse=False)

#    vae_validation = ConvVAE(z_size=args.z_size,
#                             batch_size=args.batch_size,
#                             learning_rate=args.learning_rate,
#                             kl_tolerance=args.kl_tolerance,
#                             beta=args.beta,
#                             is_training=False,
#                             reuse=False)

    images = [im for im in os.listdir(args.folder) if im.endswith('.png')]  # Changed from .jpg
    images = np.array(images)
    n_samples = len(images)

    if args.n_samples > 0:
        n_samples = min(n_samples, args.n_samples)

    print("{} images".format(n_samples))

    # indices over all training images (shuffled below)
    indices = np.arange(n_samples, dtype='int64')
    np.random.shuffle(indices)

    # split the shuffled indices into mini-batches: minibatchlist is a list of sorted
    # index arrays, one per mini-batch (a trailing partial batch is dropped)
    minibatchlist = [np.array(sorted(indices[start_idx:start_idx + args.batch_size]))
                     for start_idx in range(0, len(indices) - args.batch_size + 1, args.batch_size)]
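    # e.g. with n_samples = 10 and batch_size = 4 this yields two sorted index arrays of
    # length 4; the remaining 2 indices are simply dropped (only full batches are used)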

    data_loader = DataLoader(minibatchlist, images, n_workers=2, folder=args.folder)
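    # data_loader.py's DataLoader is expected to read and preprocess each mini-batch of
    # images with n_workers=2 background workers, overlapping disk I/O with training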

    vae_controller = VAEController(z_size=args.z_size)
    vae_controller.vae = vae
#    vae_controller_validation = VAEController(z_size=args.z_size)
#    vae_controller_validation.vae = vae_validation

    writer = tf.summary.FileWriter(
        'C:/Users/meric/Desktop/TUM/CBMLR/Repo/G1_RoboSkate/scripts/python/RoboSkateIL/VAE/logs/tensorboard',
        vae.sess.graph)
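    # The logged scalars can be inspected with: tensorboard --logdir <log directory above>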

    for epoch in range(args.n_epochs):
        pbar = tqdm(total=len(minibatchlist))
        train_loss_avg = []
        r_loss_avg = []
        kl_loss_avg = []

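        # One optimization step per mini-batch: run the train op and fetch the total,
        # reconstruction and KL losses plus the global step for logging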
        for obs in data_loader:
            feed = {vae.input_tensor: obs}
            (train_loss, r_loss, kl_loss, train_step, _) = vae.sess.run([
                vae.loss,
                vae.r_loss,
                vae.kl_loss,
                vae.global_step,
                vae.train_op
            ], feed)
            pbar.update(1)
            train_loss_avg.append(train_loss)
            kl_loss_avg.append(kl_loss)
            r_loss_avg.append(r_loss)

        pbar.close()
        train_loss_avg = np.mean(train_loss_avg)
        kl_loss_avg = np.mean(kl_loss_avg)
        r_loss_avg = np.mean(r_loss_avg)

        print("Epoch {:3}/{}".format(epoch + 1, args.n_epochs))
        print("VAE: optimization step", (train_step + 1), train_loss_avg, r_loss_avg, kl_loss_avg)
        summary_train_loss = tf.Summary(value=[tf.Summary.Value(tag="train_loss", simple_value=train_loss_avg)])
        summary_r_loss = tf.Summary(value=[tf.Summary.Value(tag="r_loss", simple_value=r_loss_avg)])
        summary_kl_loss = tf.Summary(value=[tf.Summary.Value(tag="kl_loss", simple_value=kl_loss_avg)])
        writer.add_summary(summary_r_loss, epoch)
        writer.add_summary(summary_kl_loss, epoch)
        writer.add_summary(summary_train_loss, epoch)

        # Sync the controller with the latest trained VAE parameters
        vae_controller.set_target_params()
        # Save a few sample reconstructions for visual inspection
        img_save_path = 'C:/Users/meric/Desktop/TUM/CBMLR/Repo/G1_RoboSkate/scripts/python/RoboSkateIL/VAE/logs/images'
        os.makedirs(img_save_path, exist_ok=True)
        if args.verbose >= 1:
            for idx in range(5):
                image_idx = np.random.randint(n_samples)
                image_path = args.folder + images[image_idx]
                image = cv2.imread(image_path)
                image = cv2.resize(image, (160, 80), interpolation=cv2.INTER_AREA)
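                # cv2.resize takes (width, height); 160x80 is assumed to match the
                # ConvVAE input resolution used during training (see data_loader.py)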
                # ROI will not be used.
                # r = ROI
                # im = image[int(r[1]):int(r[1] + r[3]), int(r[0]):int(r[0] + r[2])]

                encoded = vae_controller.encode(image)
                reconstructed_image = vae_controller.decode(encoded)[0]
                # Plot reconstruction
                cv2.imshow("Original", image)
                cv2.imwrite(os.path.join(img_save_path, "Original_Image_Epoch" + str(epoch) + "_" + str(idx) + ".png"),
                            image)
                cv2.imshow("Reconstruct", reconstructed_image)
                cv2.imwrite(
                    os.path.join(img_save_path,
                                 "Reconstructed_Image_Epoch" + str(epoch) + "_" + str(idx) + ".png"),
                    reconstructed_image)
                cv2.waitKey(1)

    save_path = "logs/vae-{}/150Epochs".format(args.z_size)
    os.makedirs(save_path, exist_ok=True)
    print("Saving to {}".format(save_path))
    vae_controller.set_target_params()
    vae_controller.save(save_path)
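    # Note (assumption): to reuse the trained VAE elsewhere, controller.VAEController is
    # assumed to expose a load() counterpart to save(), e.g.:
    #     vae_controller = VAEController(z_size=256)
    #     vae_controller.load(save_path)
    #     latent = vae_controller.encode(image)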


if __name__ == '__main__':
    train_vae()