# ARMED-MixedEffectsDL / armed / callbacks / segmentation.py
'''
Custom Keras callbacks for segmentation models (currently unused)
'''
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

class SaveMultiModalImagesCallback(tf.keras.callbacks.Callback):
    def __init__(self, data, arrYVal, strSaveDir, random_effects=False):
        """Callback for saving some example segmentations as images at the end of each epoch

        Images will contain a 4x4 grid. Each row contains a randomly chosen validation sample
        and each column contains a channel/modality. Ground truth masks are overlaid in red
        while predictions are overlaid in blue.

        Each row's sample index comes from a fixed-seed local random generator. The same
        samples are therefore shown every epoch, and the global NumPy RNG is left untouched.

        Args:
            data (np.array or tuple): 4D array of validation data or tuple of (data, design_matrix)
            arrYVal (np.array): 4D array of validation labels (segmentation masks)
            strSaveDir (str): path to directory for saving images
            random_effects (bool, optional): if True, ``data`` is a (data, design_matrix)
                tuple and the model is called with both inputs. Defaults to False.
        """
        super(SaveMultiModalImagesCallback, self).__init__()
        self.data = data
        self.arrYVal = arrYVal
        self.strSaveDir = strSaveDir
        self.random_effects = random_effects

    def on_epoch_end(self, epoch, logs=None):
        """Render the example grid for this epoch and save it as ``epochNNN.png`` in strSaveDir."""
        nRows, nCols = 4, 4
        fig, ax = plt.subplots(nRows, nCols, dpi=150)
        # Create a grid where each row is a different sample and each column
        # is a different modality within that sample
        for i in range(nRows):
            # Use a local, fixed-seed generator. It picks the same sample every epoch
            # (same draws as the legacy global RNG with this seed) but, unlike
            # np.random.seed, it does not reset the global NumPy RNG that other
            # training code may rely on.
            rng = np.random.RandomState(i * 2348)
            if self.random_effects:
                k = rng.randint(self.data[0].shape[0])
                arrInput = self.data[0][k,]
                arrDesign = self.data[1][k,]
                arrPredMask = self.model.predict((np.expand_dims(arrInput, 0),
                                                  np.expand_dims(arrDesign, 0))).squeeze()
            else:
                k = rng.randint(self.data.shape[0])
                arrInput = self.data[k,]
                arrPredMask = self.model.predict(np.expand_dims(arrInput, 0)).squeeze()
            arrTrueMask = self.arrYVal[k,].squeeze()

            # The overlays depend only on the sample, not on the channel, so build them once per row.
            # RGBA overlay: red channel and alpha taken from the ground truth mask.
            arrTrueOverlay = np.zeros(arrTrueMask.shape + (4,))
            arrTrueOverlay[..., 0] = arrTrueMask
            arrTrueOverlay[..., -1] = arrTrueMask
            # RGBA overlay: blue where prediction >= 0.5, alpha from the raw prediction value.
            arrPredOverlay = np.zeros(arrTrueMask.shape + (4,))
            arrPredOverlay[..., 2] = (arrPredMask >= 0.5)
            arrPredOverlay[..., -1] = arrPredMask

            # arrInput is indexed as (H, W, C) below; a sample with fewer than nCols
            # channels gets blank panels instead of raising IndexError.
            nChannels = arrInput.shape[-1]
            for j in range(nCols):
                if j < nChannels:
                    ax[i, j].imshow(arrInput[:, :, j], cmap='Greys_r')
                    ax[i, j].imshow(arrTrueOverlay, alpha=0.3)
                    ax[i, j].imshow(arrPredOverlay, alpha=0.3)
                ax[i, j].axis('off')
        fig.tight_layout()
        fig.savefig(os.path.join(self.strSaveDir, 'epoch{:03d}.png'.format(epoch)))
        plt.close(fig)


class SaveImagesCallback(tf.keras.callbacks.Callback):
    def __init__(self, data, arrYVal, strSaveDir, random_effects=False):
        """Callback for saving some example segmentations as images at the end of each epoch

        Images contain a 3x3 grid showing the first (up to) 9 validation samples, one
        sample per panel in row-major order. Ground truth masks are overlaid in red
        while predictions are overlaid in blue.

        Args:
            data (np.array or tuple): 4D array of validation data or tuple of (data, design_matrix)
            arrYVal (np.array): 4D array of validation labels (segmentation masks)
            strSaveDir (str): path to directory for saving images
            random_effects (bool, optional): if True, ``data`` is a (data, design_matrix)
                tuple and the model is called with both inputs. Defaults to False.
        """
        super(SaveImagesCallback, self).__init__()
        self.data = data
        self.arrYVal = arrYVal
        self.strSaveDir = strSaveDir
        self.random_effects = random_effects

    def on_epoch_end(self, epoch, logs=None):
        """Render the example grid for this epoch and save it as ``epochNNN.png`` in strSaveDir."""
        fig, ax = plt.subplots(3, 3, dpi=150)
        # Each panel (row-major) shows a different sample, taken in order from the
        # start of the validation set.
        axes = ax.flatten()
        arrImages = self.data[0] if self.random_effects else self.data
        # Guard against validation sets smaller than the grid; unused panels stay blank.
        nSamples = min(len(axes), arrImages.shape[0])
        for i in range(nSamples):
            arrInput = arrImages[i,]
            if self.random_effects:
                arrDesign = self.data[1][i,]
                arrPredMask = self.model.predict((np.expand_dims(arrInput, 0),
                                                  np.expand_dims(arrDesign, 0))).squeeze()
            else:
                arrPredMask = self.model.predict(np.expand_dims(arrInput, 0)).squeeze()
            arrTrueMask = self.arrYVal[i,].squeeze()
            axCell = axes[i]
            # squeeze() drops singleton dims of the sample before display
            # (presumably a single trailing channel -- TODO confirm input layout).
            axCell.imshow(arrInput.squeeze(), cmap='Greys_r')

            # RGBA overlay: red channel and alpha taken from the ground truth mask.
            arrTrueOverlay = np.zeros(arrTrueMask.shape + (4,))
            arrTrueOverlay[..., 0] = arrTrueMask
            arrTrueOverlay[..., -1] = arrTrueMask
            axCell.imshow(arrTrueOverlay, alpha=0.2)

            # RGBA overlay: blue where prediction >= 0.5, alpha from the raw prediction value.
            arrPredOverlay = np.zeros(arrTrueMask.shape + (4,))
            arrPredOverlay[..., 2] = (arrPredMask >= 0.5)
            arrPredOverlay[..., -1] = arrPredMask
            axCell.imshow(arrPredOverlay, alpha=0.3)
        for axCell in axes:
            axCell.axis('off')
        fig.tight_layout()
        fig.savefig(os.path.join(self.strSaveDir, 'epoch{:03d}.png'.format(epoch)))
        plt.close(fig)