DeepRF / envs / deeprf / environment.py
environment.py
Raw
import gym
import torch
from envs.deeprf.core import SLRExcitation


class DeepRFSLREXC20(gym.Env):
    def __init__(self, **kwargs):
        self.core = SLRExcitation(**kwargs)
        self.input_shape = self.core.input_shape
        self.action_space = self.core.action_space
        self.du = self.core.du
        self.max_amp = self.core.max_amp
        self.df = self.core.df
        self.sar_coef = kwargs.get('sar_coef', 0.0)
        self.ripple_coef = kwargs.get('ripple_coef', 1.0)
        self.max_mag = kwargs.get('max_mag', 0.9833)
        self.max_ripple = kwargs.get('max_ripple', 0.01)

    def __len__(self):
        return self.core.sampling_rate

    def reset(self):
        return self.core.reset()

    def step(self, actions):
        m = actions[..., 0]
        phi = actions[..., 1]
        Mt, done = self.core.step(m, phi)
        amp = ((torch.clamp(m, -1.0, 1.0) + 1.0) * self.core.max_amp * 1e+4 / 2).pow(2) * \
              self.du / self.core.sampling_rate * 1e+6  # (mG)^2sec
        rewards = -self.sar_coef * amp
        if done:
            Mt1 = Mt[:, 0, :1000, :]  # passband
            Mt2 = Mt[:, 0, 1000:, :]  # stopband
            mxy1 = torch.sqrt(torch.sum(Mt1[..., 0], dim=1) ** 2 + torch.sum(Mt1[..., 1], dim=1) ** 2) / 1000
            mxy2 = torch.max(torch.sqrt(Mt2[..., 0] ** 2 + Mt2[..., 1] ** 2), dim=1)[0]
            rewards += torch.clamp(mxy1, max=self.max_mag) - self.ripple_coef * torch.clamp(mxy2 - self.max_ripple, min=0.0)
        return Mt.permute(0, 2, 1, 3), rewards, done