# ARMED-MixedEffectsDL/armed/callbacks/aec_callbacks.py
'''
Custom callbacks for autoencoder-classifiers.
'''

import os
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import davies_bouldin_score, calinski_harabasz_score

from armed.metrics import image_metrics
from scipy.stats import f_oneway

def make_recon_figure_callback(images: np.array, 
                               model, 
                               output_dir: str, 
                               clusters: np.array=None, 
                               mixedeffects: bool=False,
                               max_images: int=8):
    """Generate a callback function that saves a figure of example reconstructions.

    Without ``mixedeffects`` the figure has two rows: the input images and
    their reconstructions. With ``mixedeffects`` it has four rows: the inputs,
    the reconstructions without cluster-specific effects (FE), with them (ME),
    and the difference ME - FE, plus a shared colorbar for the difference row.
    Each call writes ``epoch{NNN}.png`` (1-based epoch number) to
    ``output_dir``. The generated function should be used with the
    LambdaCallback class from Keras to create the callback object.

    Args:
        images (np.array): batch of example images (e.g. 8 x h x w x 1). The
            first ``min(len(images), max_images)`` images are plotted, one
            figure column each.
        model (tf.keras.Model): model. Element 0 of its ``predict`` output is
            used as the reconstruction; with ``mixedeffects`` elements 0 and 1
            are used as the ME and FE reconstructions, respectively.
        output_dir (str): output path
        clusters (np.array): one-hot encoded cluster design matrix if 
            needed by model (n_images x n_clusters). Defaults to None. Required
            when ``mixedeffects`` is True.
        mixedeffects (bool): include recons w/ and w/o random effects
        max_images (int): maximum number of images (figure columns) to plot.
            Defaults to 8.

    Returns:
        callable: function ``(epoch, logs)`` suitable for LambdaCallback.

    Raises:
        ValueError: if ``mixedeffects`` is True but ``clusters`` is None, or if
            there is nothing to plot.
    """
    # Fail fast: the mixed-effects model is called with (images, clusters), so a
    # missing design matrix would otherwise only surface inside predict() at the
    # end of the first epoch.
    if mixedeffects and clusters is None:
        raise ValueError('clusters must be provided when mixedeffects=True')

    nShow = min(len(images), max_images)
    if nShow < 1:
        raise ValueError('at least one image is required to plot reconstructions')

    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    def _label_rows(ax, labels):
        # Draw row titles as rotated text just left of the first column.
        for iRow, strLabel in enumerate(labels):
            ax[iRow, 0].text(-0.2, 0.5, strLabel, transform=ax[iRow, 0].transAxes,
                             va='center', ha='center', rotation=90)

    def _finish(fig, epoch):
        # Lay out and save the figure, always releasing it even if saving fails.
        try:
            # Silence tight_layout UserWarnings so they don't clutter the training log.
            with warnings.catch_warnings():
                warnings.simplefilter(action='ignore', category=UserWarning)
                fig.tight_layout(w_pad=0.1, h_pad=0.1)
            fig.savefig(os.path.join(output_dir, f'epoch{epoch+1:03d}.png'))
        finally:
            plt.close(fig)

    if mixedeffects:

        def _recon_images(epoch, logs):
            # Callback function for saving example reconstruction images after each epoch.
            # Predict before creating the figure so a failing predict() leaks nothing.
            arrReconME, arrReconFE = model.predict((images, clusters))[:2]
            arrReconDiff = arrReconME - arrReconFE
            # Symmetric color range for the difference maps, shared by all columns.
            vmax = np.abs(arrReconDiff).max()

            # The extra narrow last column hosts the colorbar.
            fig, ax = plt.subplots(4, nShow + 1, figsize=(nShow + 1, 4), squeeze=False,
                                   gridspec_kw={'hspace': 0.3,
                                                'width_ratios': [1] * nShow + [0.2]})

            for iImg in range(nShow):
                ax[0, iImg].imshow(images[iImg], cmap='gray', vmin=0., vmax=1.)
                ax[1, iImg].imshow(arrReconFE[iImg], cmap='gray', vmin=0., vmax=1.)
                ax[2, iImg].imshow(arrReconME[iImg], cmap='gray', vmin=0., vmax=1.)
                ax[3, iImg].imshow(arrReconDiff[iImg], cmap='coolwarm', vmin=-vmax, vmax=vmax)
                for iRow in range(4):
                    ax[iRow, iImg].axis('off')

            _label_rows(ax, ['Original', 'Recon: FE', 'Recon: ME', '(ME - FE)'])

            # Replace the last column's per-row axes with one colorbar spanning all rows.
            gridspec = ax[0, -1].get_gridspec()
            for a in ax[:, -1]:
                a.remove()
            axCbar = fig.add_subplot(gridspec[:, -1])
            fig.colorbar(ScalarMappable(norm=Normalize(vmin=-vmax, vmax=vmax),
                                        cmap='coolwarm'),
                         cax=axCbar)
            axCbar.set_ylabel('Difference (ME - FE)')

            _finish(fig, epoch)

    else:

        def _recon_images(epoch, logs):
            # Callback function for saving example reconstruction images after each epoch.
            if clusters is not None:
                arrRecon = model.predict((images, clusters))[0]
            else:
                arrRecon = model.predict(images)[0]

            # squeeze=False keeps ax 2-D even when only one image is plotted.
            fig, ax = plt.subplots(2, nShow, figsize=(nShow, 2), squeeze=False)

            for iImg in range(nShow):
                ax[0, iImg].imshow(images[iImg], cmap='gray', vmin=0., vmax=1.)
                ax[1, iImg].imshow(arrRecon[iImg], cmap='gray', vmin=0., vmax=1.)
                ax[0, iImg].axis('off')
                ax[1, iImg].axis('off')

            _label_rows(ax, ['Original', 'Recon'])

            _finish(fig, epoch)

    return _recon_images

def make_compute_latents_callback(model, images: np.array, image_metadata: pd.DataFrame, output_dir: str):
    """Generate a callback function that encodes images, saves the latents, and
    scores how strongly the latents cluster by date.

    Each call runs ``model.predict(images)`` and stores the result as a
    DataFrame indexed by ``image_metadata['image']`` in
    ``<output_dir>/epoch{NNN}_latents.pkl`` (1-based epoch number). It then
    computes the Davies-Bouldin and Calinski-Harabasz scores of the latents
    using ``image_metadata['date']`` as cluster labels, prints them, and
    appends them to ``<output_dir>/clustering_scores.csv``. The CSV header is
    written on epoch 0 and whenever the file is missing or empty (e.g. when
    training resumes with ``initial_epoch > 0``). The generated function should
    be used with the LambdaCallback class from Keras to create the callback
    object.

    Args:
        model (tf.keras.Model): encoder model
        images (np.array): images to encode. Row i of the latents is paired
            positionally with row i of ``image_metadata``.
        image_metadata (pd.DataFrame): metadata table with at least 'image' and
            'date' columns, one row per image
        output_dir (str): output path

    Returns:
        callable: function ``(epoch, logs)`` suitable for LambdaCallback.
    """

    def _compute_latents(epoch, logs):
        # callback function for computing latent reps for the given images and saving to a pkl file
        arrLatents = model.predict(images)
        dfLatents = pd.DataFrame(arrLatents, index=image_metadata['image'].values)
        dfLatents.to_pickle(os.path.join(output_dir, f'epoch{epoch+1:03d}_latents.pkl'))

        db = davies_bouldin_score(dfLatents, image_metadata['date'])
        ch = calinski_harabasz_score(dfLatents, image_metadata['date'])

        # NOTE: the labels are dates, not semantic classes. "Better" here
        # presumably means the latents are *less* separable by date, which is
        # the reverse of the usual reading of these scores.
        print(f'\nClustering scores:'
            f'\n\tDavies-Bouldin (higher is better): {db}'
            f'\n\tCalinski-Harabasz (lower is better): {ch}'
        )

        # Append to file. Decide on the header before opening, because mode 'a'
        # creates the file.
        strCsvPath = os.path.join(output_dir, 'clustering_scores.csv')
        bWriteHeader = (epoch == 0
                        or not os.path.exists(strCsvPath)
                        or os.path.getsize(strCsvPath) == 0)
        with open(strCsvPath, 'a') as f:
            if bWriteHeader:
                f.write('epoch,DB,CH\n')
            f.write(f'{epoch+1},{db},{ch}\n')

    return _compute_latents
    
def _predict_in_chunks(model, data_in, output_idx: int=0, chunk_size: int=1000):
    """Run ``model.predict`` over ``data_in`` in row chunks and concatenate
    output ``output_idx`` of each chunk along axis 0.

    ``data_in`` is either a single array or a tuple/list of arrays sharing the
    first dimension; every array is sliced to the same row range per chunk.

    Raises:
        ValueError: if ``data_in`` has no rows.
    """
    bMultiInput = isinstance(data_in, (tuple, list))
    nImages = (data_in[0] if bMultiInput else data_in).shape[0]
    if nImages == 0:
        raise ValueError('data_in contains no images')

    lsOutputs = []
    for iStart in range(0, nImages, chunk_size):
        iEnd = min(iStart + chunk_size, nImages)
        if bMultiInput:
            batch_in = tuple(arr[iStart:iEnd] for arr in data_in)
        else:
            batch_in = data_in[iStart:iEnd]
        lsOutputs.append(model.predict(batch_in, batch_size=32)[output_idx])
    return np.concatenate(lsOutputs, axis=0)


def compute_image_metrics(epoch: int, model, data_in, metadata: pd.DataFrame, 
                          output_dir: str, output_idx: int=0, chunk_size: int=1000):
    """Compute image metrics including brightness, contrast, sharpness, and SNR. 
    Also create histograms comparing distributions of these metrics across clusters.

    Reconstructions are predicted in chunks of ``chunk_size`` images.
    ``image_metrics`` is applied to each reconstruction. For every metric, a
    per-date density histogram is drawn (top 0.1% of values clipped from the
    plot), and a one-way ANOVA across dates is computed. The figure is saved
    as ``<output_dir>/epoch{epoch:03d}_recon_image_metrics.svg``.

    Args:
        epoch (int): epoch number (used as-is in the output filename)
        model (tf.keras.Model): model
        data_in (np.array or tuple/list of arrays): input data; multiple inputs
            must share the first (image) dimension
        metadata (pd.DataFrame): image metadata with a 'date' column, one row
            per image in the same order as ``data_in``
        output_dir (str): path to output location
        output_idx (int, optional): Index of model outputs containing the image 
            outputs. Defaults to 0.
        chunk_size (int, optional): number of images passed to each
            ``model.predict`` call. Defaults to 1000.

    Returns:
        dict: maps each metric name ('Brightness', 'Contrast', 'Sharpness',
        'SNR') to its one-way ANOVA F statistic across dates.

    Raises:
        ValueError: if ``data_in`` has no images.
    """
    arrRecons = _predict_in_chunks(model, data_in, output_idx=output_idx,
                                   chunk_size=chunk_size)

    # image_metrics must return a mapping with at least the metric keys below.
    dfMetrics = pd.DataFrame([image_metrics(img) for img in arrRecons])
    dfMetrics.index = metadata.index
    # Aligned on the index just copied from metadata.
    dfMetrics['Date'] = metadata['date']

    dictMetricNames = {'Brightness': 'Mean brightness',
                       'Contrast': 'Contrast (s.d.)',
                       'Sharpness': 'Sharpness (variance-of-Laplacian)',
                       'SNR': 'Signal-to-noise ratio'}

    dictFstats = {}
    fig, ax = plt.subplots(4, 1, figsize=(16, 13), gridspec_kw={'hspace': 0.4})
    try:
        for i, (strMetric, strAxisLabel) in enumerate(dictMetricNames.items()):
            # Clip the top 0.1% so outliers don't stretch the x-axis. The range
            # comparisons also drop NaN rows from the plot.
            vmax = dfMetrics[strMetric].quantile(0.999)
            vmin = dfMetrics[strMetric].min()
            sns.histplot(data=dfMetrics[(dfMetrics[strMetric] >= vmin) & (dfMetrics[strMetric] <= vmax)], 
                         x=strMetric, hue='Date', ax=ax[i], stat='density', bins=100)
            ax[i].set_xlabel(strAxisLabel)

            # One-way ANOVA of this metric across dates (on the unclipped values).
            lsGroups = [dfMetrics[strMetric].loc[dfMetrics['Date'] == d].values
                        for d in dfMetrics['Date'].unique()]
            fStat, _ = f_oneway(*lsGroups)
            dictFstats[strMetric] = fStat

        fig.savefig(os.path.join(output_dir, f'epoch{epoch:03d}_recon_image_metrics.svg'))
    finally:
        plt.close(fig)
    return dictFstats
    
def make_image_metrics_callback(model, data_in, metadata, output_dir, output_idx=0):
    """Generate a callback function that computes image metrics including 
    brightness, contrast, sharpness, and SNR. The generated function should be
    used with the LambdaCallback class from Keras to create the callback object.

    Each call runs ``compute_image_metrics`` with the 1-based epoch number and
    prints the resulting F statistics. It then appends one row (Epoch,
    Brightness, Contrast, Sharpness, SNR) to
    ``<output_dir>/image_metrics_fstat.csv``. The header is written on epoch 0
    and whenever the file is missing or empty (e.g. when training resumes with
    ``initial_epoch > 0``).

    Args:
        model (tf.keras.Model): model
        data_in (np.array or tuple of arrays): input data
        metadata (pd.DataFrame): image metadata
        output_dir (str): path to output location
        output_idx (int, optional): Index of model outputs containing the image 
            outputs. Defaults to 0.

    Returns:
        callable: function ``(epoch, logs)`` suitable for LambdaCallback.
    """
    lsKeys = ['Epoch', 'Brightness', 'Contrast', 'Sharpness', 'SNR']
    strCsvPath = os.path.join(output_dir, 'image_metrics_fstat.csv')

    def _fn(epoch, logs):
        metrics = compute_image_metrics(epoch+1, model, data_in, metadata, output_dir, output_idx=output_idx)
        print(metrics)

        # Append to file. Decide on the header before opening, because mode 'a'
        # creates the file.
        metrics['Epoch'] = epoch + 1
        bWriteHeader = (epoch == 0
                        or not os.path.exists(strCsvPath)
                        or os.path.getsize(strCsvPath) == 0)
        with open(strCsvPath, 'a') as f:
            if bWriteHeader:
                f.write(','.join(lsKeys) + '\n')
            f.write(','.join(str(metrics[k]) for k in lsKeys) + '\n')

    return _fn