DeepRF / envs / simulator.py
simulator.py
Raw
"""code modified from https://github.com/namalkanti/bloch-simulator-python"""
import math
from collections import Iterable
import torch


class BlochSimulator:
    def __init__(self, **kwargs):
        """Bloch simulation for B1, gradients and off-resonance.

            Args::
                gamma: gyromagnetic ratio (Hz/T)
                M0: array of initial magnetization [N, P, 3]
                ts: time duration of each B1 and gradient point (sec) [T]
                T1: T1 relaxation time (sec) [P]
                T2: T2 relaxation time (sec) [P]
                df: array of off-resonance frequencies (Hz) [N]
                dp: array of spatial positions (m) [P, 3]

        """
        # Define constants
        keys = ['gamma', 'M0', 'ts', 'T1', 'T2', 'df', 'dp']
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for k in keys:
            msg = "BlochSimulator must have '{}' as its argument".format(k)
            v = kwargs.get(k, None)
            assert v is not None, msg

            if isinstance(v, Iterable):
                v = torch.FloatTensor(v).to(self.device)
            self.__dict__.update({k: v})

        # Initialize simulator
        _ = self.reset()

    def reset(self):
        """Initialize magnetization state vector
        """
        self.time = 0.0
        self.idx = 0
        self.M = self.M0.clone().unsqueeze(0)  # [1, N, P, 3]
        return self.M

    def step(self, B1, G):
        """Simulate single time step of the Bloch equation

            Args::
                B1: complex RF pulse for given time t (T) [B, 2]
                G: 3-dimensional gradient (T/m) [B, 3]

            Returns::
                Mt: batch of magnetization at current time step [B, N, P, 3]
                done: boolean flag for denoting the end of simulation

        """
        msg = "arguments must be provided by batch"
        assert B1.ndim == 2 and G.ndim == 2, msg
        assert B1.shape[-1] == 2, "B1 field must be complex"
        assert G.shape[-1] == 3, "spatial gradient must be 3-dimensional"

        # Convert SI units
        ts = self.ts[self.idx]  # (sec)
        gamma = 2 * math.pi * self.gamma  # (rad/sec·T)
        df = 2 * math.pi * self.df  # (rad/sec)
        B1_real = B1[..., 0]  # (T)
        B1_imag = B1[..., 1]  # (T)

        # Rotation angles
        nx = -B1_real[..., None, None] * gamma * ts  # (rad)
        ny = B1_imag[..., None, None] * gamma * ts  # (rad)
        nz = (G.unsqueeze(1) * self.dp.unsqueeze(0) * gamma).sum(-1)  # (rad/sec)
        nz = nz.unsqueeze(1) + df.unsqueeze(0).unsqueeze(-1)  # (rad/sec)
        nz = -nz * ts  # [B, N, P], (rad)

        # Cayley-Klein parameters
        phi = torch.sqrt(nx ** 2 + ny ** 2 + nz ** 2)  # [B, N, P]
        sin_phi = torch.sin(phi / 2) / phi
        # sin_phi[sin_phi != sin_phi] = 0.0  # Convert nans to zero
        cos_phi = torch.cos(phi / 2)
        ar = cos_phi
        ai = -nz * sin_phi
        br = ny * sin_phi
        bi = -nx * sin_phi

        # Rotation matrix
        rot_mat = torch.zeros((*phi.shape, 3, 3), device=self.device)  # [B, N, P, 3, 3]
        rot_mat[..., 0, 0] = ar ** 2 - ai ** 2 - br ** 2 + bi ** 2
        rot_mat[..., 1, 0] = -2 * ar * ai - 2 * br * bi
        rot_mat[..., 2, 0] = -2 * ar * br + 2 * ai * bi
        rot_mat[..., 0, 1] = 2 * ar * ai - 2 * br * bi
        rot_mat[..., 1, 1] = ar ** 2 - ai ** 2 + br ** 2 - bi ** 2
        rot_mat[..., 2, 1] = -2 * ai * br - 2 * ar * bi
        rot_mat[..., 0, 2] = 2 * ar * br + 2 * ai * bi
        rot_mat[..., 1, 2] = 2 * ar * bi - 2 * ai * br
        rot_mat[..., 2, 2] = ar ** 2 + ai ** 2 - br ** 2 - bi ** 2

        Mt = self.M.view(self.M.size(0), -1, 3).unsqueeze(-1)
        Mt = Mt.expand(rot_mat.size(0), -1, -1, -1)
        Mt = Mt.reshape(-1, 3, 1)  # [B x N x P, 3, 1]
        rot_mat = rot_mat.view(-1, 3, 3)  # [B x N x P, 3, 3]
        Mt = torch.bmm(rot_mat, Mt)  # [B x N x P, 3, 1]

        # Time decay matrix
        e1 = torch.exp(-ts / self.T1)
        e2 = torch.exp(-ts / self.T2)
        dec_mat = torch.zeros((len(e1), 3, 3), device=self.device)  # [P, 3, 3]
        dec_mat[:, 0, 0] = e2
        dec_mat[:, 1, 1] = e2
        dec_mat[:, 2, 2] = e1
        dec_vec = torch.zeros((len(e1), 3), device=self.device)  # [P, 3]
        dec_vec[:, 2] = 1 - e1

        size = (phi.size(0) * phi.size(1) * phi.size(2), -1, -1, -1)
        dec_mat = dec_mat.unsqueeze(0).expand(size)
        dec_mat = dec_mat.view(-1, 3, 3)  # [B x N x P, 3, 3]
        Mt = torch.bmm(dec_mat, Mt)  # [B x N x P, 3, 1]
        Mt = Mt.view(*phi.shape, 3)  # [B, N, P, 3]
        dec_vec = dec_vec[None, None, ...]  # [1, 1, P, 3]
        Mt = Mt + dec_vec  # [B, N, P, 3]

        # Step forward and check if done
        self.M = Mt
        self.time += self.ts[self.idx]
        self.idx += 1
        done = (self.idx == len(self.ts))

        return Mt, done