# utils/scheduler.py

import math


class PolynomialDecayLR:
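    """Polynomial ("poly") learning-rate decay with optional linear warmup.

    Steps once per batch. During warmup the LR ramps linearly from 0 to
    base_lr; afterwards it decays as base_lr * (1 - iter / max_iters) ** power.
    If the optimizer has a second param group (e.g. a non-backbone head),
    its LR is scaled by nbb_mult.
    """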
    def __init__(self, optimizer, base_lr, max_iters, num_batches, power=0.9, nbb_mult=10, warmup_epochs=0):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.max_iters = max_iters
        self.power = power
        self.nbb_mult = nbb_mult
        self.cur_iters = 0
        self.cur_epoch = 0
        self.num_batches = num_batches
        self.warmup_epochs = warmup_epochs

    def step(self):
        # Count decay iterations only once warmup has finished; the epoch
        # counter itself advances fractionally, once per batch.
        if self.cur_epoch >= self.warmup_epochs:
            self.cur_iters += 1
        self.cur_epoch += 1 / self.num_batches
        self._adjust_learning_rate()

    def _adjust_learning_rate(self):
        if self.cur_epoch < self.warmup_epochs:
            # Linear warmup from 0 to base_lr.
            lr = self.base_lr * self.cur_epoch / self.warmup_epochs
        else:
            # Polynomial decay; clamp progress at 1 so the fractional power
            # never receives a negative base (which would yield a complex
            # result), and floor the LR at 1e-8.
            progress = min(self.cur_iters / self.max_iters, 1.0)
            lr = max(self.base_lr * (1.0 - progress) ** self.power, 1e-8)
        self.optimizer.param_groups[0]['lr'] = lr
        if len(self.optimizer.param_groups) == 2:
            # Second param group (non-backbone) gets a scaled LR.
            self.optimizer.param_groups[1]['lr'] = lr * self.nbb_mult
        return lr

    def get_last_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]
    
    def state_dict(self):
        return {
            'cur_iters': self.cur_iters,
            'cur_epoch': self.cur_epoch
        }

    def load_state_dict(self, state_dict):
        self.cur_iters = state_dict['cur_iters']
        self.cur_epoch = state_dict['cur_epoch']
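
# Usage sketch (hypothetical training loop; assumes a torch.optim-style
# optimizer whose optional second param group holds non-backbone parameters):
#
#     scheduler = PolynomialDecayLR(optimizer, base_lr=1e-2,
#                                   max_iters=epochs * len(loader),
#                                   num_batches=len(loader), warmup_epochs=5)
#     for epoch in range(epochs):
#         for batch in loader:
#             ...               # forward/backward + optimizer.step()
#             scheduler.step()  # one scheduler step per batch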

class CosineDecay:
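    """Cosine-annealing learning-rate schedule with optional linear warmup.

    Steps once per batch, so cur_epoch advances in increments of
    1 / num_batches. After warmup the LR follows half a cosine period from
    base_lr down to 0 at max_epochs.
    """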
    def __init__(self, optimizer, base_lr, epochs, num_batches, warmup):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.max_epochs = epochs
        self.num_batches = num_batches
        self.cur_epoch = 0
        self.warmup_epochs = warmup
        self.nbb_mult = 10

    def step(self):
        self.cur_epoch += 1 / self.num_batches
        self._adjust_learning_rate()

    def _adjust_learning_rate(self):
        if self.cur_epoch < self.warmup_epochs:
            # Linear warmup from 0 to base_lr.
            lr = self.base_lr * self.cur_epoch / self.warmup_epochs
        else:
            # Half-cosine anneal from base_lr down to 0; clamp progress at 1
            # so the LR cannot go negative if stepping continues past
            # max_epochs.
            progress = min((self.cur_epoch - self.warmup_epochs)
                           / (self.max_epochs - self.warmup_epochs), 1.0)
            lr = self.base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
        self.optimizer.param_groups[0]['lr'] = lr
        if len(self.optimizer.param_groups) == 2:
            # Both groups currently share the same LR; nbb_mult is left unused.
            self.optimizer.param_groups[1]['lr'] = lr  # * self.nbb_mult
        return lr

    def get_last_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]
    
    def state_dict(self):
        return {
            'cur_epoch': self.cur_epoch
        }

    def load_state_dict(self, state_dict):
        self.cur_epoch = state_dict['cur_epoch']
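

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): both
    # schedulers only read and write optimizer.param_groups, so a bare
    # stand-in object is enough to exercise them without torch.
    class _DummyOptimizer:
        def __init__(self):
            self.param_groups = [{'lr': 0.0}, {'lr': 0.0}]

    epochs, num_batches = 10, 100

    opt = _DummyOptimizer()
    poly = PolynomialDecayLR(opt, base_lr=0.01, max_iters=epochs * num_batches,
                             num_batches=num_batches, warmup_epochs=1)
    for _ in range(epochs * num_batches):
        poly.step()
    print('poly final lrs:', poly.get_last_lr())

    opt = _DummyOptimizer()
    cosine = CosineDecay(opt, base_lr=0.01, epochs=epochs,
                         num_batches=num_batches, warmup=1)
    for _ in range(epochs * num_batches):
        cosine.step()
    print('cosine final lrs:', cosine.get_last_lr())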