# ARMED-MixedEffectsDL / armed / models / cnn_classifier.py
'''
Simple 2D CNN classifiers, including domain adversarial and mixed effects extensions.
'''
import tensorflow as tf
import tensorflow.keras.layers as tkl
from .random_effects import RandomEffects

class ImageClassifier(tf.keras.Model):
    '''
    Simple 2D image binary classifier with 7 convolution blocks and 2 final dense layers.
    '''

    def __init__(self, name='classifier', **kwargs):
        """Simple 2D image binary classifier with 7 convolution blocks and 2 final dense layers.

        Args:
            name (str, optional): Model name. Defaults to 'classifier'.
            **kwargs: Forwarded to the tf.keras.Model constructor.
        """
        super(ImageClassifier, self).__init__(name=name, **kwargs)

        # The "D x D" comments give the spatial size entering each block for a
        # D x D input image.
        # D x D
        self.conv0 = tkl.Conv2D(64, 3, padding='same', name='conv0')
        self.bn0 = tkl.BatchNormalization(name='bn0')
        self.prelu0 = tkl.PReLU(name='prelu0')
        self.maxpool0 = tkl.MaxPool2D(name='maxpool0')
        # D/2 x D/2
        self.conv1 = tkl.Conv2D(128, 3, padding='same', name='conv1')
        self.dropout1 = tkl.Dropout(0.5, name='dropout1')
        self.bn1 = tkl.BatchNormalization(name='bn1')
        self.prelu1 = tkl.PReLU(name='prelu1')
        self.maxpool1 = tkl.MaxPool2D(name='maxpool1')
        # D/4 x D/4
        self.conv2 = tkl.Conv2D(128, 3, padding='same', name='conv2')
        self.bn2 = tkl.BatchNormalization(name='bn2')
        self.prelu2 = tkl.PReLU(name='prelu2')
        self.maxpool2 = tkl.MaxPool2D(name='maxpool2')
        # D/8 x D/8
        self.conv3 = tkl.Conv2D(256, 3, padding='same', name='conv3')
        self.dropout3 = tkl.Dropout(0.5, name='dropout3')
        self.bn3 = tkl.BatchNormalization(name='bn3')
        self.prelu3 = tkl.PReLU(name='prelu3')
        self.maxpool3 = tkl.MaxPool2D(name='maxpool3')
        # D/16 x D/16
        self.conv4 = tkl.Conv2D(256, 3, padding='same', name='conv4')
        self.bn4 = tkl.BatchNormalization(name='bn4')
        self.prelu4 = tkl.PReLU(name='prelu4')
        self.maxpool4 = tkl.MaxPool2D(name='maxpool4')
        # D/32 x D/32
        self.conv5 = tkl.Conv2D(512, 3, padding='same', name='conv5')
        self.dropout5 = tkl.Dropout(0.5, name='dropout5')
        self.bn5 = tkl.BatchNormalization(name='bn5')
        self.prelu5 = tkl.PReLU(name='prelu5')
        self.maxpool5 = tkl.MaxPool2D(name='maxpool5')
        # D/64 x D/64 (this last convolution is unpadded and has no batch norm)
        self.conv6 = tkl.Conv2D(512, 3, padding='valid', name='conv6')
        self.prelu6 = tkl.PReLU(name='prelu6')
        self.flatten = tkl.Flatten(name='flatten')
        self.dense = tkl.Dense(512, name='dense', activation='relu',
                               kernel_regularizer=tf.keras.regularizers.L1L2(l1=0.01))
        # The output logit and the sigmoid are kept as separate layers.
        self.out = tkl.Dense(1, name='output')
        self.sigmoid = tkl.Activation('sigmoid', name='sigmoid')

    def call(self, inputs, return_layer_activations=False, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch of images fed to the first convolution.
            return_layer_activations (bool, optional): If True, also return the
                activations of every convolution block and of the hidden dense
                layer. Defaults to False.
            training (bool, optional): Forwarded to the dropout and batch
                normalization layers to select training or inference
                behavior. None (default) lets Keras infer the mode from the
                calling context.

        Returns:
            The sigmoid output y, or the tuple
            (c0, c1, c2, c3, c4, c5, c6, h, y) when return_layer_activations
            is True.
        """
        c0 = self.conv0(inputs)
        c0 = self.bn0(c0, training=training)
        c0 = self.prelu0(c0)

        c1 = self.maxpool0(c0)
        c1 = self.conv1(c1)
        c1 = self.dropout1(c1, training=training)
        c1 = self.bn1(c1, training=training)
        c1 = self.prelu1(c1)

        c2 = self.maxpool1(c1)
        c2 = self.conv2(c2)
        c2 = self.bn2(c2, training=training)
        c2 = self.prelu2(c2)

        c3 = self.maxpool2(c2)
        c3 = self.conv3(c3)
        c3 = self.dropout3(c3, training=training)
        c3 = self.bn3(c3, training=training)
        c3 = self.prelu3(c3)

        c4 = self.maxpool3(c3)
        c4 = self.conv4(c4)
        c4 = self.bn4(c4, training=training)
        c4 = self.prelu4(c4)

        c5 = self.maxpool4(c4)
        c5 = self.conv5(c5)
        c5 = self.dropout5(c5, training=training)
        c5 = self.bn5(c5, training=training)
        c5 = self.prelu5(c5)

        c6 = self.maxpool5(c5)
        c6 = self.conv6(c6)
        c6 = self.prelu6(c6)

        h = self.flatten(c6)
        h = self.dense(h)
        y = self.out(h)
        y = self.sigmoid(y)

        if return_layer_activations:
            return c0, c1, c2, c3, c4, c5, c6, h, y
        return y

class ClusterInputImageClassifier(ImageClassifier):
    '''
    Simple 2D image classifier that takes an additional input containing the
    one-hot encoded cluster membership of each sample. This input is
    concatenated to the flattened output of the last convolution layer, before
    the first dense layer.
    '''
    def call(self, inputs, return_layer_activations=False, training=None):
        """Run the forward pass.

        Args:
            inputs: Pair (x, z) where x is the image batch and z is the
                one-hot encoded cluster membership of each sample.
            return_layer_activations (bool, optional): If True, also return the
                activations of every convolution block and of the hidden dense
                layer. Defaults to False.
            training (bool, optional): Forwarded to the dropout and batch
                normalization layers to select training or inference
                behavior. None (default) lets Keras infer the mode from the
                calling context.

        Returns:
            The sigmoid output y, or the tuple
            (c0, c1, c2, c3, c4, c5, c6, h, y) when return_layer_activations
            is True.
        """
        x, z = inputs

        c0 = self.conv0(x)
        c0 = self.bn0(c0, training=training)
        c0 = self.prelu0(c0)

        c1 = self.maxpool0(c0)
        c1 = self.conv1(c1)
        c1 = self.dropout1(c1, training=training)
        c1 = self.bn1(c1, training=training)
        c1 = self.prelu1(c1)

        c2 = self.maxpool1(c1)
        c2 = self.conv2(c2)
        c2 = self.bn2(c2, training=training)
        c2 = self.prelu2(c2)

        c3 = self.maxpool2(c2)
        c3 = self.conv3(c3)
        c3 = self.dropout3(c3, training=training)
        c3 = self.bn3(c3, training=training)
        c3 = self.prelu3(c3)

        c4 = self.maxpool3(c3)
        c4 = self.conv4(c4)
        c4 = self.bn4(c4, training=training)
        c4 = self.prelu4(c4)

        c5 = self.maxpool4(c4)
        c5 = self.conv5(c5)
        c5 = self.dropout5(c5, training=training)
        c5 = self.bn5(c5, training=training)
        c5 = self.prelu5(c5)

        c6 = self.maxpool5(c5)
        c6 = self.conv6(c6)
        c6 = self.prelu6(c6)

        h = self.flatten(c6)
        # Append the cluster membership to the flattened image features so the
        # dense layers can condition on it.
        h = tf.concat([h, z], axis=1)
        h = self.dense(h)
        y = self.out(h)
        y = self.sigmoid(y)

        if return_layer_activations:
            return c0, c1, c2, c3, c4, c5, c6, h, y
        return y
        
class AdversarialClassifier(tf.keras.Model):
    '''
    Domain adversarial classifier for the ImageClassifier. Receives the
    layer activations from the ImageClassifier as inputs and predicts
    cluster membership.
    '''
    def __init__(self, n_clusters, name='adversary', **kwargs):
        """Domain adversarial classifier for the ImageClassifier. Receives the
        layer activations from the ImageClassifier as inputs and predicts
        cluster membership.

        Args:
            n_clusters (int): number of clusters
            name (str, optional): Model name. Defaults to 'adversary'.
            **kwargs: Forwarded to the tf.keras.Model constructor.
        """
        super(AdversarialClassifier, self).__init__(name=name, **kwargs)
        self.n_clusters = n_clusters

        # The "D x D" comments give the spatial size entering each block,
        # mirroring the classifier whose activations are concatenated in call().
        # D x D
        self.conv0 = tkl.Conv2D(64, 3, padding='same', name='conv0')
        self.bn0 = tkl.BatchNormalization(name='bn0')
        self.prelu0 = tkl.PReLU(name='prelu0')
        self.maxpool0 = tkl.MaxPool2D(name='maxpool0')
        # D/2 x D/2
        self.conv1 = tkl.Conv2D(64, 3, padding='same', name='conv1')
        self.dropout1 = tkl.Dropout(0.5, name='dropout1')
        self.bn1 = tkl.BatchNormalization(name='bn1')
        self.prelu1 = tkl.PReLU(name='prelu1')
        self.maxpool1 = tkl.MaxPool2D(name='maxpool1')
        # D/4 x D/4
        self.conv2 = tkl.Conv2D(128, 3, padding='same', name='conv2')
        self.bn2 = tkl.BatchNormalization(name='bn2')
        self.prelu2 = tkl.PReLU(name='prelu2')
        self.maxpool2 = tkl.MaxPool2D(name='maxpool2')
        # D/8 x D/8
        self.conv3 = tkl.Conv2D(128, 3, padding='same', name='conv3')
        self.dropout3 = tkl.Dropout(0.5, name='dropout3')
        self.bn3 = tkl.BatchNormalization(name='bn3')
        self.prelu3 = tkl.PReLU(name='prelu3')
        self.maxpool3 = tkl.MaxPool2D(name='maxpool3')
        # D/16 x D/16
        self.conv4 = tkl.Conv2D(256, 3, padding='same', name='conv4')
        self.bn4 = tkl.BatchNormalization(name='bn4')
        self.prelu4 = tkl.PReLU(name='prelu4')
        self.maxpool4 = tkl.MaxPool2D(name='maxpool4')
        # D/32 x D/32
        self.conv5 = tkl.Conv2D(256, 3, padding='same', name='conv5')
        self.dropout5 = tkl.Dropout(0.5, name='dropout5')
        self.bn5 = tkl.BatchNormalization(name='bn5')
        self.prelu5 = tkl.PReLU(name='prelu5')
        self.maxpool5 = tkl.MaxPool2D(name='maxpool5')
        # D/64 x D/64 (this last convolution is unpadded and has no batch norm)
        self.conv6 = tkl.Conv2D(512, 3, padding='valid', name='conv6')
        self.prelu6 = tkl.PReLU(name='prelu6')
        self.flatten = tkl.Flatten(name='flatten')
        self.dense = tkl.Dense(512, name='dense', activation='relu')
        self.out = tkl.Dense(n_clusters, name='output', activation='softmax')

    def call(self, inputs, training=None):
        """Predict cluster membership from the classifier's layer activations.

        Args:
            inputs: Sequence of 8 tensors (c0, c1, c2, c3, c4, c5, c6, h):
                the activations of the classifier's convolution blocks and
                of its hidden dense layer. c6 is accepted but not used.
            training (bool, optional): Forwarded to the dropout and batch
                normalization layers to select training or inference
                behavior. None (default) lets Keras infer the mode from the
                calling context.

        Returns:
            Softmax output over the n_clusters clusters.
        """
        c0, c1, c2, c3, c4, c5, c6, h = inputs

        x = self.conv0(c0)
        x = self.bn0(x, training=training)
        x = self.prelu0(x)

        # After each pooling step, the classifier activation of the matching
        # resolution is concatenated along the channel axis.
        x = self.maxpool0(x)
        x = tf.concat([x, c1], axis=-1)
        x = self.conv1(x)
        x = self.dropout1(x, training=training)
        x = self.bn1(x, training=training)
        x = self.prelu1(x)

        x = self.maxpool1(x)
        x = tf.concat([x, c2], axis=-1)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.prelu2(x)

        x = self.maxpool2(x)
        x = tf.concat([x, c3], axis=-1)
        x = self.conv3(x)
        x = self.dropout3(x, training=training)
        x = self.bn3(x, training=training)
        x = self.prelu3(x)

        x = self.maxpool3(x)
        x = tf.concat([x, c4], axis=-1)
        x = self.conv4(x)
        x = self.bn4(x, training=training)
        x = self.prelu4(x)

        x = self.maxpool4(x)
        x = tf.concat([x, c5], axis=-1)
        x = self.conv5(x)
        x = self.dropout5(x, training=training)
        x = self.bn5(x, training=training)
        x = self.prelu5(x)

        x = self.maxpool5(x)
        # Don't concatenate c6 because the tensor shapes don't line up
        x = self.conv6(x)
        x = self.prelu6(x)
        x = self.flatten(x)
        # The classifier's hidden dense activations join the flattened features.
        x = tf.concat([x, h], axis=-1)
        x = self.dense(x)
        x = self.out(x)
        return x
        
        
class DomainAdversarialImageClassifier(tf.keras.Model):
    '''
    Domain adversarial 2D image classifier which learns the classification 
    task while competing with an adversary, which learns to predict cluster 
    membership from the classifier's layer activations.
    '''

    def __init__(self, n_clusters, name='da_classifier', **kwargs):
        """Domain adversarial 2D image classifier which learns the classification 
        task while competing with an adversary, which learns to predict cluster 
        membership from the classifier's layer activations.

        Args:
            n_clusters (int): number of clusters
            name (str, optional): Model name. Defaults to 'da_classifier'.
            **kwargs: Forwarded to the tf.keras.Model constructor.
        """
        super(DomainAdversarialImageClassifier, self).__init__(name=name, **kwargs)
        self.classifier = ImageClassifier()
        self.adversary = AdversarialClassifier(n_clusters)
        self.n_clusters = n_clusters

    def call(self, inputs, training=None):
        """Predict the class probability; the adversary is not involved.

        Args:
            inputs: Pair (x, z) of images and cluster membership. z is
                accepted for interface consistency but not used here.
            training (bool, optional): Forwarded to the classifier. None
                (default) lets Keras infer the mode from the calling context.

        Returns:
            The classifier's sigmoid output.
        """
        x, _ = inputs
        return self.classifier(x, training=training)

    def compile(self,
                loss_classifier=None,
                loss_adversary=None,
                loss_classifier_weight=1.0,
                loss_gen_weight=0.1,
                metric_classifier=None,
                opt_classifier=None,
                opt_adversary=None
                ):
        """Compile model. Must be called before training or loading weights.

        Arguments left as None get a fresh default object on every call, so
        that optimizer and metric state is never shared between models.

        Args:
            loss_classifier (loss, optional): Main classification loss. 
                Defaults to tf.keras.losses.BinaryCrossentropy().
            loss_adversary (loss, optional): Adversary classification loss. 
                Defaults to tf.keras.losses.CategoricalCrossentropy().
            loss_classifier_weight (float, optional): Main classification loss weight. 
                Defaults to 1.0.
            loss_gen_weight (float, optional): Generalization loss weight. 
                Defaults to 0.1.
            metric_classifier (metric, optional): Main classification metric. 
                Defaults to tf.keras.metrics.AUC().
            opt_classifier (optimizer, optional): Optimizer for main classifier. 
                Defaults to tf.keras.optimizers.Nadam(learning_rate=0.0001).
            opt_adversary (optimizer, optional): Optimizer for adversary. 
                Defaults to tf.keras.optimizers.Nadam(learning_rate=0.0001).
        """
        super().compile()

        if loss_classifier is None:
            loss_classifier = tf.keras.losses.BinaryCrossentropy()
        if loss_adversary is None:
            loss_adversary = tf.keras.losses.CategoricalCrossentropy()
        if metric_classifier is None:
            metric_classifier = tf.keras.metrics.AUC()
        if opt_classifier is None:
            opt_classifier = tf.keras.optimizers.Nadam(learning_rate=0.0001)
        if opt_adversary is None:
            opt_adversary = tf.keras.optimizers.Nadam(learning_rate=0.0001)

        self.loss_classifier = loss_classifier
        self.loss_adversary = loss_adversary
        self.loss_classifier_weight = loss_classifier_weight
        self.loss_gen_weight = loss_gen_weight
        self.metric_classifier = metric_classifier
        self.opt_classifier = opt_classifier
        self.opt_adversary = opt_adversary

        # Running means of each loss term, reported through `metrics`.
        self.loss_classifier_tracker = tf.keras.metrics.Mean(name='class_loss')
        self.loss_gen_tracker = tf.keras.metrics.Mean(name='gen_loss')
        self.loss_adversary_tracker = tf.keras.metrics.Mean(name='adv_loss')
        self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss')

    @property
    def metrics(self):
        """Loss trackers and the classification metric (available after compile)."""
        return [self.loss_classifier_tracker,
                self.loss_gen_tracker,
                self.loss_adversary_tracker,
                self.loss_total_tracker,
                self.metric_classifier]

    def _compute_update_loss(self, loss_class, loss_gen):
        '''Compute total loss and update loss running means'''
        self.loss_classifier_tracker.update_state(loss_class)
        self.loss_gen_tracker.update_state(loss_gen)

        # NOTE(review): loss_gen is the adversary's own loss and is *added*
        # here, so with a positive loss_gen_weight the classifier is pushed to
        # make the adversary's job easier. Confirm whether callers are
        # expected to pass a negative weight or a negated adversary loss.
        loss_total = (self.loss_classifier_weight * loss_class) + (self.loss_gen_weight * loss_gen)
        self.loss_total_tracker.update_state(loss_total)

        return loss_total

    def train_step(self, data):
        """Run one adversary update followed by one classifier update.

        Args:
            data: ((images, clusters), labels) or
                ((images, clusters), labels, sample_weights).

        Returns:
            dict mapping each metric name to its current value.
        """
        if len(data) == 3:
            (images, clusters), labels, sample_weights = data
        else:
            (images, clusters), labels = data
            sample_weights = None

        # training=True is passed explicitly: inside a custom train_step there
        # is no outer Keras call to infer the mode from, so dropout and batch
        # normalization would otherwise run in inference mode.

        # train adversary
        with tf.GradientTape() as gt:
            clf_outs = self.classifier(images, return_layer_activations=True, training=True)
            layer_activations = clf_outs[:-1]
            clusters_pred = self.adversary(layer_activations, training=True)
            loss_adv = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_weights)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_weights))
        self.loss_adversary_tracker.update_state(loss_adv)

        # train main classifier
        with tf.GradientTape() as gt2:
            clf_outs = self.classifier(images, return_layer_activations=True, training=True)
            y_pred = clf_outs[-1]
            loss_class = self.loss_classifier(tf.reshape(labels, (-1, 1)), 
                                              tf.reshape(y_pred, (-1, 1)), 
                                              sample_weight=sample_weights)

            layer_activations = clf_outs[:-1]
            clusters_pred = self.adversary(layer_activations, training=True)
            loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

            # NOTE(review): self.classifier.losses (e.g. the dense layer's
            # kernel regularizer) are not part of loss_total.
            loss_total = self._compute_update_loss(loss_class, loss_gen)

        # Only the classifier's weights are updated in this step.
        grads = gt2.gradient(loss_total, self.classifier.trainable_weights)
        self.opt_classifier.apply_gradients(zip(grads, self.classifier.trainable_weights))

        self.metric_classifier.update_state(labels, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the losses and the metric on one batch, in inference mode.

        Args:
            data: ((images, clusters), labels) or
                ((images, clusters), labels, sample_weights).

        Returns:
            dict mapping each metric name to its current value.
        """
        if len(data) == 3:
            (images, clusters), labels, sample_weights = data
        else:
            (images, clusters), labels = data
            sample_weights = None

        clf_outs = self.classifier(images, return_layer_activations=True, training=False)
        y_pred = clf_outs[-1]
        # Same reshaping as in train_step so that both losses are comparable.
        loss_class = self.loss_classifier(tf.reshape(labels, (-1, 1)),
                                          tf.reshape(y_pred, (-1, 1)),
                                          sample_weight=sample_weights)

        layer_activations = clf_outs[:-1]
        clusters_pred = self.adversary(layer_activations, training=False)
        loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        _ = self._compute_update_loss(loss_class, loss_gen)

        self.metric_classifier.update_state(labels, y_pred)
        return {m.name: m.result() for m in self.metrics}
    
class RandomEffectsClassifier(ImageClassifier):
    '''
    2D CNN binary classifier with random effect slopes and intercept applied
    around the output layer.
    '''

    def __init__(self, 
                 slope_post_init_scale=0.1,
                 intercept_post_init_scale=0.1,
                 slope_prior_scale=0.25,
                 intercept_prior_scale=0.25,
                 kl_weight=1e-3,
                 name='re_classifier',
                 **kwargs):
        """2D CNN binary classifier with random effect slopes and intercept.

        Args:
            slope_post_init_scale (float, optional): S.d. for random normal initialization 
                of random slopes. Defaults to 0.1.
            intercept_post_init_scale (float, optional): S.d. for random normal 
                initialization of random intercepts. Defaults to 0.1.
            slope_prior_scale (float, optional): S.d. of prior distribution for random 
                slopes. Defaults to 0.25.
            intercept_prior_scale (float, optional): S.d. of prior distribution for random 
                intercepts. Defaults to 0.25.
            kl_weight (float, optional): KL divergence weight. Defaults to 1e-3.
            name (str, optional): Model name. Defaults to 're_classifier'.
            **kwargs: Forwarded to the ImageClassifier constructor.
        """
        super(RandomEffectsClassifier, self).__init__(name=name, **kwargs)

        # Single slope and intercept RE layer
        # 512 slope units: one per unit of the hidden dense layer.
        self.re_slopes = RandomEffects(units=512,
                                       post_loc_init_scale=slope_post_init_scale,
                                       prior_scale=slope_prior_scale,
                                       kl_weight=kl_weight,
                                       name='re_slopes')

        self.re_intercept = RandomEffects(units=1,
                                       post_loc_init_scale=intercept_post_init_scale,
                                       prior_scale=intercept_prior_scale,
                                       kl_weight=kl_weight,
                                       name='re_intercept')

    def call(self, inputs, return_layer_activations=False, training=None):
        """Run the forward pass.

        Args:
            inputs: Pair (x, z) where x is the image batch and z is the
                cluster membership input fed to the random effects layers
                (presumably one-hot encoded, as in
                ClusterInputImageClassifier - TODO confirm).
            return_layer_activations (bool, optional): If True, also return the
                activations of every convolution block and of the hidden dense
                layer. Defaults to False.
            training (bool, optional): Forwarded to the dropout, batch
                normalization and random effects layers. None (default)
                lets Keras infer the mode from the calling context.

        Returns:
            The sigmoid output y, or the tuple
            (c0, c1, c2, c3, c4, c5, c6, h, y) when return_layer_activations
            is True.
        """
        x, z = inputs

        c0 = self.conv0(x)
        c0 = self.bn0(c0, training=training)
        c0 = self.prelu0(c0)

        c1 = self.maxpool0(c0)
        c1 = self.conv1(c1)
        c1 = self.dropout1(c1, training=training)
        c1 = self.bn1(c1, training=training)
        c1 = self.prelu1(c1)

        c2 = self.maxpool1(c1)
        c2 = self.conv2(c2)
        c2 = self.bn2(c2, training=training)
        c2 = self.prelu2(c2)

        c3 = self.maxpool2(c2)
        c3 = self.conv3(c3)
        c3 = self.dropout3(c3, training=training)
        c3 = self.bn3(c3, training=training)
        c3 = self.prelu3(c3)

        c4 = self.maxpool3(c3)
        c4 = self.conv4(c4)
        c4 = self.bn4(c4, training=training)
        c4 = self.prelu4(c4)

        c5 = self.maxpool4(c4)
        c5 = self.conv5(c5)
        c5 = self.dropout5(c5, training=training)
        c5 = self.bn5(c5, training=training)
        c5 = self.prelu5(c5)

        c6 = self.maxpool5(c5)
        c6 = self.conv6(c6)
        c6 = self.prelu6(c6)

        h = self.flatten(c6)
        h = self.dense(h)

        slopes = self.re_slopes(z, training=training)
        intercepts = self.re_intercept(z, training=training)
        # NOTE(review): MixedEffectsClassifier scales h by (1 + slopes) while
        # this class uses the raw slopes - confirm the difference is intended.
        y = self.out(h * slopes)
        # The random intercept is added to the logit, before the sigmoid.
        y = self.sigmoid(y + intercepts)

        if return_layer_activations:
            return c0, c1, c2, c3, c4, c5, c6, h, y
        return y
        
    
class MixedEffectsClassifier(DomainAdversarialImageClassifier):
    ''' Random linear slope and intercept introduced before last dense layer
    '''
    def __init__(self, 
                 n_clusters,
                 slope_post_init_scale=0.1,
                 intercept_post_init_scale=0.1,
                 slope_prior_scale=0.25,
                 intercept_prior_scale=0.25,
                 kl_weight=1e-3,
                 name='me_classifier', **kwargs):
        """2D CNN binary classifier using domain adversarial training to enforce the learning 
        of cluster-agnostic convolutional features. Additionally, cluster-specific random 
        slopes and intercepts are learned in the dense layers of the model. 

        Args:
            n_clusters (int): number of clusters
            slope_post_init_scale (float, optional): S.d. for random normal initialization 
                of random slopes. Defaults to 0.1.
            intercept_post_init_scale (float, optional): S.d. for random normal 
                initialization of random intercepts. Defaults to 0.1.
            slope_prior_scale (float, optional): S.d. of prior distribution for random 
                slopes. Defaults to 0.25.
            intercept_prior_scale (float, optional): S.d. of prior distribution for random 
                intercepts. Defaults to 0.25.
            kl_weight (float, optional): KL divergence weight. Defaults to 1e-3.
            name (str, optional): Model name. Defaults to 'me_classifier'.
            **kwargs: Forwarded to the DomainAdversarialImageClassifier constructor.
        """
        super(MixedEffectsClassifier, self).__init__(n_clusters, name=name, **kwargs)

        # Single slope and intercept RE layer
        # 512 slope units: one per unit of the classifier's hidden dense layer.
        self.re_slopes = RandomEffects(units=512,
                                       post_loc_init_scale=slope_post_init_scale,
                                       prior_scale=slope_prior_scale,
                                       kl_weight=kl_weight,
                                       name='re_slopes')

        self.re_intercept = RandomEffects(units=1,
                                       post_loc_init_scale=intercept_post_init_scale,
                                       prior_scale=intercept_prior_scale,
                                       kl_weight=kl_weight,
                                       name='re_intercept')

    def call(self, inputs, return_layer_activations=False, training=None):
        """Run the forward pass through the classifier's layers and the random effects.

        The classifier's layers are invoked one by one (instead of calling
        self.classifier) so that the random slopes and intercept can be
        inserted around its output layer.

        Args:
            inputs: Pair (x, z) where x is the image batch and z is the
                cluster membership input fed to the random effects layers.
            return_layer_activations (bool, optional): If True, also return the
                activations of every convolution block and of the hidden dense
                layer. Defaults to False.
            training (bool, optional): Forwarded to the dropout, batch
                normalization and random effects layers. None (default)
                lets Keras infer the mode from the calling context.

        Returns:
            The sigmoid output y, or the tuple
            (c0, c1, c2, c3, c4, c5, c6, h, y) when return_layer_activations
            is True.
        """
        x, z = inputs
        clf = self.classifier

        c0 = clf.conv0(x)
        c0 = clf.bn0(c0, training=training)
        c0 = clf.prelu0(c0)

        c1 = clf.maxpool0(c0)
        c1 = clf.conv1(c1)
        c1 = clf.dropout1(c1, training=training)
        c1 = clf.bn1(c1, training=training)
        c1 = clf.prelu1(c1)

        c2 = clf.maxpool1(c1)
        c2 = clf.conv2(c2)
        c2 = clf.bn2(c2, training=training)
        c2 = clf.prelu2(c2)

        c3 = clf.maxpool2(c2)
        c3 = clf.conv3(c3)
        c3 = clf.dropout3(c3, training=training)
        c3 = clf.bn3(c3, training=training)
        c3 = clf.prelu3(c3)

        c4 = clf.maxpool3(c3)
        c4 = clf.conv4(c4)
        c4 = clf.bn4(c4, training=training)
        c4 = clf.prelu4(c4)

        c5 = clf.maxpool4(c4)
        c5 = clf.conv5(c5)
        c5 = clf.dropout5(c5, training=training)
        c5 = clf.bn5(c5, training=training)
        c5 = clf.prelu5(c5)

        c6 = clf.maxpool5(c5)
        c6 = clf.conv6(c6)
        c6 = clf.prelu6(c6)

        h = clf.flatten(c6)
        h = clf.dense(h)

        slopes = self.re_slopes(z, training=training)
        intercepts = self.re_intercept(z, training=training)
        # The random slopes rescale the hidden features around 1; the random
        # intercept is added to the logit, before the sigmoid.
        y = clf.out(h * (1 + slopes))
        y = clf.sigmoid(y + intercepts)

        if return_layer_activations:
            return c0, c1, c2, c3, c4, c5, c6, h, y
        return y

    def compile(self,
            loss_classifier=None,
            loss_adversary=None,
            loss_classifier_weight=1.0,
            loss_gen_weight=0.1,
            metric_classifier=None,
            opt_classifier=None,
            opt_adversary=None
            ):
        """Compile model. Must be called before training or loading weights.

        Arguments left as None get a fresh default object on every call, so
        that optimizer and metric state is never shared between models.

        Args:
            loss_classifier (loss, optional): Main classification loss. 
                Defaults to tf.keras.losses.BinaryCrossentropy().
            loss_adversary (loss, optional): Adversary classification loss. 
                Defaults to tf.keras.losses.CategoricalCrossentropy().
            loss_classifier_weight (float, optional): Main classification loss weight. 
                Defaults to 1.0.
            loss_gen_weight (float, optional): Generalization loss weight. 
                Defaults to 0.1.
            metric_classifier (metric, optional): Main classification metric. 
                Defaults to tf.keras.metrics.AUC().
            opt_classifier (optimizer, optional): Optimizer for main classifier. 
                Defaults to tf.keras.optimizers.Nadam(learning_rate=0.0001).
            opt_adversary (optimizer, optional): Optimizer for adversary. 
                Defaults to tf.keras.optimizers.Nadam(learning_rate=0.0001).
        """
        # Defaults are resolved here, so the parent always receives concrete objects.
        if loss_classifier is None:
            loss_classifier = tf.keras.losses.BinaryCrossentropy()
        if loss_adversary is None:
            loss_adversary = tf.keras.losses.CategoricalCrossentropy()
        if metric_classifier is None:
            metric_classifier = tf.keras.metrics.AUC()
        if opt_classifier is None:
            opt_classifier = tf.keras.optimizers.Nadam(learning_rate=0.0001)
        if opt_adversary is None:
            opt_adversary = tf.keras.optimizers.Nadam(learning_rate=0.0001)

        super(MixedEffectsClassifier, self).compile(loss_classifier=loss_classifier,
                                                    loss_adversary=loss_adversary,
                                                    loss_classifier_weight=loss_classifier_weight,
                                                    loss_gen_weight=loss_gen_weight,
                                                    metric_classifier=metric_classifier,
                                                    opt_classifier=opt_classifier,
                                                    opt_adversary=opt_adversary)
        self.loss_kld_tracker = tf.keras.metrics.Mean(name='kld')

    @property
    def metrics(self):
        """Loss trackers (including KLD) and the classification metric (available after compile)."""
        return [self.loss_classifier_tracker,
                self.loss_gen_tracker,
                self.loss_adversary_tracker,
                self.loss_kld_tracker,
                self.loss_total_tracker,
                self.metric_classifier]

    def _compute_update_loss(self, loss_class, loss_gen, training=True):
        '''Compute total loss and update loss running means'''
        self.loss_classifier_tracker.update_state(loss_class)
        self.loss_gen_tracker.update_state(loss_gen)
        if training:
            kld = tf.reduce_sum(self.re_slopes.losses) + tf.reduce_sum(self.re_intercept.losses)
            self.loss_kld_tracker.update_state(kld)
        else:
            # KLD can't be computed at inference time because posteriors are simplified to 
            # point estimates
            kld = 0

        # NOTE(review): loss_gen is the adversary's own loss and is *added*
        # here, so with a positive loss_gen_weight the classifier is pushed to
        # make the adversary's job easier. Confirm whether callers are
        # expected to pass a negative weight or a negated adversary loss.
        loss_total = (self.loss_classifier_weight * loss_class) \
            + (self.loss_gen_weight * loss_gen) \
            + kld
        self.loss_total_tracker.update_state(loss_total)

        return loss_total

    def train_step(self, data):
        """Run one adversary update followed by one classifier + random effects update.

        Args:
            data: ((images, clusters), labels) or
                ((images, clusters), labels, sample_weights).

        Returns:
            dict mapping each metric name to its current value.
        """
        if len(data) == 3:
            (images, clusters), labels, sample_weights = data
        else:
            (images, clusters), labels = data
            sample_weights = None

        # training=True is passed explicitly: inside a custom train_step there
        # is no outer Keras call to infer the mode from, so dropout and batch
        # normalization would otherwise run in inference mode.

        # train adversary
        with tf.GradientTape() as gt:
            clf_outs = self((images, clusters), return_layer_activations=True, training=True)
            layer_activations = clf_outs[:-1]
            clusters_pred = self.adversary(layer_activations, training=True)
            loss_adv = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_weights)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_weights))
        self.loss_adversary_tracker.update_state(loss_adv)

        # train main classifier
        with tf.GradientTape() as gt2:
            clf_outs = self((images, clusters), return_layer_activations=True, training=True)
            y_pred = clf_outs[-1]
            loss_class = self.loss_classifier(tf.reshape(labels, (-1, 1)), 
                                    tf.reshape(y_pred, (-1, 1)), 
                                    sample_weight=sample_weights)

            layer_activations = clf_outs[:-1]
            clusters_pred = self.adversary(layer_activations, training=True)
            loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

            loss_total = self._compute_update_loss(loss_class, loss_gen)

        # The random effects layers are attributes of this model, not of
        # self.classifier, so their weights must be listed explicitly;
        # otherwise they would never be updated.
        weights = (self.classifier.trainable_weights
                   + self.re_slopes.trainable_weights
                   + self.re_intercept.trainable_weights)
        grads = gt2.gradient(loss_total, weights)
        self.opt_classifier.apply_gradients(zip(grads, weights))

        self.metric_classifier.update_state(labels, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the losses and the metric on one batch, in inference mode.

        The KLD term is skipped (see _compute_update_loss).

        Args:
            data: ((images, clusters), labels) or
                ((images, clusters), labels, sample_weights).

        Returns:
            dict mapping each metric name to its current value.
        """
        if len(data) == 3:
            (images, clusters), labels, sample_weights = data
        else:
            (images, clusters), labels = data
            sample_weights = None

        clf_outs = self((images, clusters), return_layer_activations=True, training=False)
        y_pred = clf_outs[-1]
        # Same reshaping as in train_step so that both losses are comparable.
        loss_class = self.loss_classifier(tf.reshape(labels, (-1, 1)),
                                          tf.reshape(y_pred, (-1, 1)),
                                          sample_weight=sample_weights)

        layer_activations = clf_outs[:-1]
        clusters_pred = self.adversary(layer_activations, training=False)
        loss_gen = self.loss_adversary(clusters, clusters_pred, sample_weight=sample_weights)

        _ = self._compute_update_loss(loss_class, loss_gen, training=False)

        self.metric_classifier.update_state(labels, y_pred)
        return {m.name: m.result() for m in self.metrics}