# ARMED-MixedEffectsDL / armed / models / mlp_classifiers.py
"""Simple neural networks for classification: baseline MLPs plus
domain-adversarial and mixed-effects MLP classifiers.
"""
import tensorflow as tf
import tensorflow.keras.layers as tkl
from .random_effects import RandomEffects

class BaseMLP(tf.keras.Model):
    def __init__(self, name: str='mlp', **kwargs):
        """Small feed-forward binary classifier.

        Three ReLU hidden layers of 4 units each feed a single sigmoid
        output unit.

        Args:
            name (str, optional): Model name; also used as the prefix of every
                sub-layer name. Defaults to 'mlp'.
        """
        super().__init__(name=name, **kwargs)

        # Sub-layers are assigned in the same order as before so Keras tracks
        # (and orders) their weights identically.
        self.dense0 = tkl.Dense(4, activation='relu', name=name + '_dense0')
        self.dense1 = tkl.Dense(4, activation='relu', name=name + '_dense1')
        self.dense2 = tkl.Dense(4, activation='relu', name=name + '_dense2')
        self.dense_out = tkl.Dense(1, activation='sigmoid', name=name + '_dense_out')

    def call(self, inputs):
        """Forward pass; returns the sigmoid output of the last layer."""
        out = inputs
        for layer in (self.dense0, self.dense1, self.dense2, self.dense_out):
            out = layer(out)
        return out
    
class ClusterCovariateMLP(BaseMLP):
    """
    MLP that uses cluster membership as extra input covariates: the site
    membership design matrix is concatenated to the data (along axis 1)
    before the BaseMLP layer stack is applied.
    """
    def call(self, inputs):
        features, membership = inputs

        out = tf.concat((features, membership), axis=1)
        for layer in (self.dense0, self.dense1, self.dense2, self.dense_out):
            out = layer(out)
        return out


class MLPActivations(tkl.Layer):
    def __init__(self, last_activation: str='sigmoid', name: str='mlp_activations', **kwargs):
        """MLP layer with 3 hidden layers of 4 neurons each that also exposes
        its hidden activations. Intended to be used within a domain adversarial
        model.

        Calling the layer returns the tuple
        ``(hidden0, hidden1, hidden2, output)``.

        Args:
            last_activation (str, optional): Activation of output layer.
                Defaults to 'sigmoid'.
            name (str, optional): Layer name; also used as the prefix of every
                sub-layer name. Defaults to 'mlp_activations'.
        """
        super(MLPActivations, self).__init__(name=name, **kwargs)

        # Stored so get_config() can rebuild the layer with the same output activation.
        self.last_activation = last_activation

        self.dense0 = tkl.Dense(4, activation='relu', name=name + '_dense0')
        self.dense1 = tkl.Dense(4, activation='relu', name=name + '_dense1')
        self.dense2 = tkl.Dense(4, activation='relu', name=name + '_dense2')
        self.dense_out = tkl.Dense(1, activation=last_activation, name=name + '_dense_out')

    def call(self, inputs):
        """Return the three hidden activations followed by the layer output."""
        x0 = self.dense0(inputs)
        x1 = self.dense1(x0)
        x2 = self.dense2(x1)
        out = self.dense_out(x2)

        return x0, x1, x2, out

    def get_config(self):
        """Serializable config: base layer config plus ``last_activation``.

        Previously this returned ``{}``, so ``from_config`` rebuilt the layer
        with the default name and a 'sigmoid' output even when it had been
        created with e.g. ``last_activation='linear'``.
        """
        config = super(MLPActivations, self).get_config()
        config.update({'last_activation': self.last_activation})
        return config
    
class Adversary(tkl.Layer):
    def __init__(self,
                 n_clusters: int,
                 layer_units: list=None,
                 name: str='adversary',
                 **kwargs):
        """Adversarial classifier that predicts cluster membership.

        A stack of ReLU dense layers followed by a softmax output layer with
        one unit per cluster.

        Args:
            n_clusters (int): number of clusters (classes)
            layer_units (list, optional): Neurons in each hidden layer. Can be a
                list of any length. Defaults to None, which uses [8, 8, 4].
            name (str, optional): Layer name; also used as the prefix of every
                sub-layer name. Defaults to 'adversary'.
        """
        super(Adversary, self).__init__(name=name, **kwargs)

        # Copy (and resolve the default here, not in the signature) so neither
        # the caller's list nor a shared default list is aliased by the layer
        # or by the dict returned from get_config().
        layer_units = [8, 8, 4] if layer_units is None else list(layer_units)

        self.n_clusters = n_clusters
        self.layer_units = layer_units

        dense_layers = [tkl.Dense(neurons,
                                  activation='relu',
                                  name=name + '_dense' + str(i_layer))
                        for i_layer, neurons in enumerate(layer_units)]
        dense_layers.append(tkl.Dense(n_clusters, activation='softmax', name=name + '_dense_out'))
        self.layers = dense_layers

    def call(self, inputs):
        """Apply every dense layer in order; returns softmax cluster probabilities."""
        x = inputs
        for layer in self.layers:
            x = layer(x)

        return x

    def get_config(self):
        """Serializable config: base layer config plus constructor arguments."""
        config = super(Adversary, self).get_config()
        config.update({'n_clusters': self.n_clusters,
                       'layer_units': list(self.layer_units)})
        return config
        
class DomainAdversarialMLP(tf.keras.Model):
    def __init__(self,
                 n_clusters: int,
                 adversary_layer_units: list=None,
                 name: str='da_mlp',
                 **kwargs):
        """Domain adversarial MLP classifier. The main model learns the classification
        task while the adversary prevents it from learning cluster-related features.

        Args:
            n_clusters (int): Number of clusters.
            adversary_layer_units (list, optional): Neurons in each layer of the
                adversary. Defaults to None, which uses [8, 8, 4].
            name (str, optional): Model name. Defaults to 'da_mlp'.
        """
        super(DomainAdversarialMLP, self).__init__(name=name, **kwargs)

        # Resolved here instead of in the signature to avoid a shared mutable default.
        if adversary_layer_units is None:
            adversary_layer_units = [8, 8, 4]

        self.classifier = MLPActivations(name='mlp')
        self.adversary = Adversary(n_clusters=n_clusters,
                                   layer_units=list(adversary_layer_units),
                                   name='adversary')

    def call(self, inputs):
        """Run the classifier and the adversary.

        Args:
            inputs: Tuple ``(x, z)``. ``z`` is unpacked but not used here.

        Returns:
            Tuple ``(pred_class, pred_cluster)``: the classifier output and the
            adversary's prediction computed from the classifier's concatenated
            hidden-layer activations.
        """
        x, z = inputs
        classifier_outs = self.classifier(x)
        pred_class = classifier_outs[-1]
        activations = tf.concat(classifier_outs[:3], axis=1)
        pred_cluster = self.adversary(activations)

        return pred_class, pred_cluster

    def compile(self,
                loss_class=None,
                loss_adv=None,
                metric_class=None,
                metric_adv=None,
                opt_main=None,
                opt_adversary=None,
                loss_class_weight=1.0,
                loss_gen_weight=1.0,
                ):
        """Compile model with selected losses and metrics. Must be called before training.

        Loss weights apply to the main model:
        total_loss = loss_class_weight * loss_class - loss_gen_weight * loss_adv

        Any argument left as None gets a freshly created default, so metrics and
        optimizers are never shared between models or between compile() calls.

        Args:
            loss_class (loss, optional): Main classification loss. Defaults to
                tf.keras.losses.BinaryCrossentropy().
            loss_adv (loss, optional): Adversary classification loss. Defaults to
                tf.keras.losses.CategoricalCrossentropy().
            metric_class (metric, optional): Main classification metric. Defaults to
                tf.keras.metrics.AUC(curve='PR', name='auprc').
            metric_adv (metric, optional): Adversary classification metric. Defaults to
                tf.keras.metrics.CategoricalAccuracy(name='acc').
            opt_main (optimizer, optional): Main optimizer. Defaults to
                tf.keras.optimizers.Adam(learning_rate=0.001).
            opt_adversary (optimizer, optional): Adversary optimizer. Defaults to
                tf.keras.optimizers.Adam(learning_rate=0.001).
            loss_class_weight (float, optional): Classification loss weight. Defaults to 1.0.
            loss_gen_weight (float, optional): Generalization loss weight. Defaults to 1.0.
        """
        super().compile()

        # Defaults are built per call: metrics carry state, and an optimizer
        # instance must not be reused across different sets of variables.
        self.loss_class = (loss_class if loss_class is not None
                           else tf.keras.losses.BinaryCrossentropy())
        self.loss_adv = (loss_adv if loss_adv is not None
                         else tf.keras.losses.CategoricalCrossentropy())

        self.opt_main = (opt_main if opt_main is not None
                         else tf.keras.optimizers.Adam(learning_rate=0.001))
        self.opt_adversary = (opt_adversary if opt_adversary is not None
                              else tf.keras.optimizers.Adam(learning_rate=0.001))

        # Trackers for running mean of each loss
        self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss')
        self.loss_adv_tracker = tf.keras.metrics.Mean(name='adv_loss')
        self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss')

        self.metric_class = (metric_class if metric_class is not None
                             else tf.keras.metrics.AUC(curve='PR', name='auprc'))
        self.metric_adv = (metric_adv if metric_adv is not None
                           else tf.keras.metrics.CategoricalAccuracy(name='acc'))

        self.loss_class_weight = loss_class_weight
        self.loss_gen_weight = loss_gen_weight

    @property
    def metrics(self):
        # Listing these lets Keras reset them at the start of each epoch.
        return [self.loss_class_tracker,
                self.loss_adv_tracker,
                self.loss_total_tracker,
                self.metric_class,
                self.metric_adv]

    def train_step(self, data):
        """One adversary update followed by one classifier update.

        ``data`` is ``((x, z), y)`` or ``((x, z), y, sample_weight)``.
        """
        # Unpack data, including sample weights if provided
        if len(data) == 3:
            (data, clusters), labels, sample_weights = data
        else:
            (data, clusters), labels = data
            sample_weights = None

        # Get hidden layer activations from classifier and train the adversary.
        # Activations are computed outside the tape, so only the adversary's
        # variables receive gradients in this step.
        activations = tf.concat(self.classifier(data)[:-1], axis=1)
        with tf.GradientTape() as gt:
            pred_cluster = self.adversary(activations)
            loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)

        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_variables)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_variables))

        self.metric_adv.update_state(clusters, pred_cluster)
        self.loss_adv_tracker.update_state(loss_adv)

        # Train the main classifier: subtracting the adversary loss rewards
        # features from which the cluster cannot be predicted.
        with tf.GradientTape() as gt2:
            pred_class, pred_cluster = self((data, clusters))
            loss_class = self.loss_class(labels, pred_class, sample_weight=sample_weights)
            loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)

            total_loss = (self.loss_class_weight * loss_class) \
                - (self.loss_gen_weight * loss_adv)

        grads_class = gt2.gradient(total_loss, self.classifier.trainable_variables)
        self.opt_main.apply_gradients(zip(grads_class, self.classifier.trainable_variables))

        self.metric_class.update_state(labels, pred_class)
        self.loss_class_tracker.update_state(loss_class)
        self.loss_total_tracker.update_state(total_loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate losses and metrics without updating any weights.

        Accepts the same ``((x, z), y)`` / ``((x, z), y, sample_weight)``
        layouts as train_step.
        """
        if len(data) == 3:
            (data, clusters), labels, sample_weights = data
        else:
            (data, clusters), labels = data
            sample_weights = None

        pred_class, pred_cluster = self((data, clusters))
        loss_class = self.loss_class(labels, pred_class, sample_weight=sample_weights)
        loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)

        total_loss = (self.loss_class_weight * loss_class) \
            - (self.loss_gen_weight * loss_adv)

        self.metric_class.update_state(labels, pred_class)
        self.metric_adv.update_state(clusters, pred_cluster)

        self.loss_class_tracker.update_state(loss_class)
        self.loss_adv_tracker.update_state(loss_adv)
        self.loss_total_tracker.update_state(total_loss)

        return {m.name: m.result() for m in self.metrics}
    

class RandomEffectsLinearSlopeIntercept(tkl.Layer):
    def __init__(self,
                 slopes: int,
                 slope_posterior_init_scale: float=0.1,
                 intercept_posterior_init_scale: float=0.1,
                 slope_prior_scale: float=0.1,
                 intercept_prior_scale: float=0.1,
                 kl_weight: float=0.001,
                 name: str='randomeffects', **kwargs):
        """Layer that learns a random linear slope and intercept.

        When called on an input (x, z) it returns the tuple
        ``(reduce_sum(x * random_slope(z), axis=1, keepdims=True), random_intercept(z))``,
        i.e. the per-sample dot product of ``x`` with its cluster's slopes, and
        the cluster's intercept.

        Args:
            slopes (int): Dimensionality of the slopes (i.e. the number of features).
            slope_posterior_init_scale (float, optional): Scale for initializing slope
                posterior means with a random normal distribution. Defaults to 0.1.
            intercept_posterior_init_scale (float, optional): Scale for initializing intercept
                posterior means with a random normal distribution. Defaults to 0.1.
            slope_prior_scale (float, optional): Scale of slope prior distribution. Defaults to 0.1.
            intercept_prior_scale (float, optional): Scale of intercept prior distribution.
                Defaults to 0.1.
            kl_weight (float, optional): KL divergence loss weight, passed to both
                RandomEffects sub-layers. Defaults to 0.001.
            name (str, optional): Layer name. Defaults to 'randomeffects'.
        """
        super(RandomEffectsLinearSlopeIntercept, self).__init__(name=name, **kwargs)

        self.slopes = slopes
        self.slope_posterior_init_scale = slope_posterior_init_scale
        self.intercept_posterior_init_scale = intercept_posterior_init_scale
        self.slope_prior_scale = slope_prior_scale
        self.intercept_prior_scale = intercept_prior_scale
        self.kl_weight = kl_weight

        self.re_slope = RandomEffects(slopes,
                                      post_loc_init_scale=slope_posterior_init_scale,
                                      prior_scale=slope_prior_scale,
                                      kl_weight=kl_weight, name=name + '_re_slope')
        # NOTE(review): not used by call(); kept so the attribute remains
        # available to any code referencing it. Candidate for removal.
        self.dense_out = tkl.Dense(1, name=name + '_re_out')

        self.re_int = RandomEffects(1,
                                    post_loc_init_scale=intercept_posterior_init_scale,
                                    prior_scale=intercept_prior_scale,
                                    kl_weight=kl_weight,
                                    name=name + '_re_int')

    def call(self, inputs, training=None):
        """Return ``(slope_term, intercept)`` for inputs ``(x, z)``."""
        x, z = inputs
        slope = self.re_slope(z, training=training)
        # Linear random-slope term: per-sample sum over axis 1 of x * slope
        # (assumes x and slope broadcast elementwise, presumably both
        # (batch, slopes) — TODO confirm against RandomEffects' output shape).
        prod = tf.reduce_sum(x * slope, axis=1, keepdims=True)
        intercept = self.re_int(z, training=training)

        return prod, intercept

    def get_config(self):
        """Serializable config: base layer config plus constructor arguments."""
        config = super(RandomEffectsLinearSlopeIntercept, self).get_config()
        config.update({'slopes': self.slopes,
                       'slope_posterior_init_scale': self.slope_posterior_init_scale,
                       'intercept_posterior_init_scale': self.intercept_posterior_init_scale,
                       'slope_prior_scale': self.slope_prior_scale,
                       'intercept_prior_scale': self.intercept_prior_scale,
                       'kl_weight': self.kl_weight})
        return config
        
class MixedEffectsMLP(DomainAdversarialMLP):
    def __init__(self, n_features: int, n_clusters: int,
                 adversary_layer_units: list=None,
                 slope_posterior_init_scale: float=0.1,
                 intercept_posterior_init_scale: float=0.1,
                 slope_prior_scale: float=0.1,
                 intercept_prior_scale: float=0.1,
                 kl_weight: float=0.001,
                 name: str='me_mlp',
                 **kwargs):
        """Mixed effects MLP classifier. Includes an adversarial classifier to
        disentangle the predictive features from the cluster-specific features.
        The cluster-specific features are then learned by a random effects layer.

        This architecture includes linear random slopes (to be multiplied by the
        input) and random intercept. The model output is
        sigmoid((fixed effect logit) + (random slopes) * X + (random intercept))

        Args:
            n_features (int): Number of features.
            n_clusters (int): Number of clusters.
            adversary_layer_units (list, optional): Neurons in each layer of the
                adversary. Defaults to None, which uses [8, 8, 4].
            slope_posterior_init_scale (float, optional): Scale for initializing slope
                posterior means with a random normal distribution. Defaults to 0.1.
            intercept_posterior_init_scale (float, optional): Scale for initializing intercept
                posterior means with a random normal distribution. Defaults to 0.1.
            slope_prior_scale (float, optional): Scale of slope prior distribution. Defaults to 0.1.
            intercept_prior_scale (float, optional): Scale of intercept prior distribution.
                Defaults to 0.1.
            kl_weight (float, optional): KL divergence loss weight. Defaults to 0.001.
            name (str, optional): Model name. Defaults to 'me_mlp'.
        """
        # Resolved here instead of in the signature to avoid a shared mutable default.
        if adversary_layer_units is None:
            adversary_layer_units = [8, 8, 4]

        super(MixedEffectsMLP, self).__init__(n_clusters=n_clusters,
                                              adversary_layer_units=adversary_layer_units,
                                              name=name, **kwargs)
        # Replace the parent's sigmoid classifier: the fixed-effects output must
        # be a logit so it can be summed with the random effects before the sigmoid.
        self.classifier = MLPActivations(last_activation='linear', name='mlp')

        self.randomeffects = RandomEffectsLinearSlopeIntercept(
                        n_features,
                        slope_posterior_init_scale=slope_posterior_init_scale,
                        intercept_posterior_init_scale=intercept_posterior_init_scale,
                        slope_prior_scale=slope_prior_scale,
                        intercept_prior_scale=intercept_prior_scale,
                        kl_weight=kl_weight)

    def call(self, inputs, training=None):
        """Return ``(pred_class_me, pred_class_fe, pred_cluster)`` for inputs ``(x, z)``."""
        x, z = inputs
        fe_outs = self.classifier(x)
        pred_class_fe = tf.nn.sigmoid(fe_outs[-1])

        re_prod, re_int = self.randomeffects((x, z), training=training)
        # Mixed-effects prediction: random terms are added in logit space.
        pred_class_me = tf.nn.sigmoid(re_prod + re_int + fe_outs[-1])

        fe_activations = tf.concat(fe_outs[:3], axis=1)
        pred_cluster = self.adversary(fe_activations)

        return pred_class_me, pred_class_fe, pred_cluster

    def compile(self,
                loss_class=None,
                loss_adv=None,
                metric_class_me=None,
                metric_class_fe=None,
                metric_adv=None,
                opt_main=None,
                opt_adversary=None,
                loss_class_me_weight=1.0,
                loss_class_fe_weight=1.0,
                loss_gen_weight=1.0,
                ):
        """Compile model with selected losses and metrics. Must be called before training.

        Loss weights apply to the main model:
        total_loss = loss_class_me_weight * loss_class_me + loss_class_fe_weight * loss_class_fe
            - loss_gen_weight * loss_adv + (sum of the random effects layer's losses)

        Any argument left as None gets a freshly created default, so metrics and
        optimizers are never shared between models or between compile() calls.

        Args:
            loss_class (loss, optional): Main classification loss. This applies to both the
                mixed and fixed effects-based classifications. Defaults to
                tf.keras.losses.BinaryCrossentropy().
            loss_adv (loss, optional): Adversary classification loss. Defaults to
                tf.keras.losses.CategoricalCrossentropy().
            metric_class_me (metric, optional): Metric for classification using mixed effects.
                Defaults to tf.keras.metrics.AUC(curve='PR', name='auprc').
            metric_class_fe (metric, optional): Metric for classification using fixed effects.
                Defaults to tf.keras.metrics.AUC(curve='PR', name='auprc_fe').
            metric_adv (metric, optional): Adversary classification metric. Defaults to
                tf.keras.metrics.CategoricalAccuracy(name='acc').
            opt_main (optimizer, optional): Main optimizer. Defaults to
                tf.keras.optimizers.Adam(learning_rate=0.001).
            opt_adversary (optimizer, optional): Adversary optimizer. Defaults to
                tf.keras.optimizers.Adam(learning_rate=0.001).
            loss_class_me_weight (float, optional): Weight for classification using mixed
                effects. Defaults to 1.0.
            loss_class_fe_weight (float, optional): Weight for classification using fixed
                effects. Defaults to 1.0.
            loss_gen_weight (float, optional): Generalization loss weight. Defaults to 1.0.
        """
        super().compile()

        # Defaults are built per call: metrics carry state, and an optimizer
        # instance must not be reused across different sets of variables.
        self.loss_class = (loss_class if loss_class is not None
                           else tf.keras.losses.BinaryCrossentropy())
        self.loss_adv = (loss_adv if loss_adv is not None
                         else tf.keras.losses.CategoricalCrossentropy())

        self.opt_main = (opt_main if opt_main is not None
                         else tf.keras.optimizers.Adam(learning_rate=0.001))
        self.opt_adversary = (opt_adversary if opt_adversary is not None
                              else tf.keras.optimizers.Adam(learning_rate=0.001))

        # Loss trackers
        self.loss_class_me_tracker = tf.keras.metrics.Mean(name='class_me_loss')
        self.loss_class_fe_tracker = tf.keras.metrics.Mean(name='class_fe_loss')
        self.loss_adv_tracker = tf.keras.metrics.Mean(name='adv_loss')
        self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss')

        self.metric_class_me = (metric_class_me if metric_class_me is not None
                                else tf.keras.metrics.AUC(curve='PR', name='auprc'))
        self.metric_class_fe = (metric_class_fe if metric_class_fe is not None
                                else tf.keras.metrics.AUC(curve='PR', name='auprc_fe'))
        self.metric_adv = (metric_adv if metric_adv is not None
                           else tf.keras.metrics.CategoricalAccuracy(name='acc'))

        self.loss_class_me_weight = loss_class_me_weight
        self.loss_class_fe_weight = loss_class_fe_weight
        self.loss_gen_weight = loss_gen_weight

        # The parent's single-classifier tracker/weight/metric are replaced by
        # the separate ME/FE ones above.
        del self.loss_class_tracker, self.loss_class_weight, self.metric_class

    @property
    def metrics(self):
        # Listing these lets Keras reset them at the start of each epoch.
        return [self.loss_class_me_tracker,
                self.loss_class_fe_tracker,
                self.loss_adv_tracker,
                self.loss_total_tracker,
                self.metric_class_me,
                self.metric_class_fe,
                self.metric_adv]

    def _random_effects_loss(self):
        """Sum of the losses registered by the random effects layer, as a scalar.

        ``self.randomeffects.losses`` is a Python list (presumably the KL terms
        of its RandomEffects sub-layers). Adding the list directly to a scalar
        tensor broadcasts the total loss into a vector, so it is reduced with
        tf.add_n. An empty list contributes 0.
        """
        re_losses = self.randomeffects.losses
        if not re_losses:
            return 0.0
        return tf.add_n(re_losses)

    def train_step(self, data):
        """One adversary update followed by one update of the FE classifier and random effects.

        ``data`` is ``((x, z), y)`` or ``((x, z), y, sample_weight)``.
        """
        # Unpack data, including sample weights if provided
        if len(data) == 3:
            (data, clusters), labels, sample_weights = data
        else:
            (data, clusters), labels = data
            sample_weights = None

        # Get hidden layer activations from classifier and train the adversary.
        # Activations are computed outside the tape, so only the adversary's
        # variables receive gradients in this step.
        activations = tf.concat(self.classifier(data)[:-1], axis=1)
        with tf.GradientTape() as gt:
            pred_cluster = self.adversary(activations)
            loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)

        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_variables)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_variables))

        self.metric_adv.update_state(clusters, pred_cluster)
        self.loss_adv_tracker.update_state(loss_adv)

        # Train the main classifier and the random effects
        with tf.GradientTape() as gt2:
            pred_class_me, pred_class_fe, pred_cluster = self((data, clusters), training=True)
            loss_class_me = self.loss_class(labels, pred_class_me, sample_weight=sample_weights)
            loss_class_fe = self.loss_class(labels, pred_class_fe, sample_weight=sample_weights)
            loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)

            total_loss = (self.loss_class_me_weight * loss_class_me) \
                + (self.loss_class_fe_weight * loss_class_fe) \
                - (self.loss_gen_weight * loss_adv) \
                + self._random_effects_loss()

        lsVars = self.classifier.trainable_variables + self.randomeffects.trainable_variables
        grads_class = gt2.gradient(total_loss, lsVars)
        self.opt_main.apply_gradients(zip(grads_class, lsVars))

        self.metric_class_me.update_state(labels, pred_class_me)
        self.metric_class_fe.update_state(labels, pred_class_fe)
        self.loss_class_me_tracker.update_state(loss_class_me)
        self.loss_class_fe_tracker.update_state(loss_class_fe)
        self.loss_total_tracker.update_state(total_loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate losses and metrics without updating any weights.

        Accepts the same ``((x, z), y)`` / ``((x, z), y, sample_weight)``
        layouts as train_step.
        """
        if len(data) == 3:
            (data, clusters), labels, sample_weights = data
        else:
            (data, clusters), labels = data
            sample_weights = None

        pred_class_me, pred_class_fe, pred_cluster = self((data, clusters), training=False)
        loss_class_me = self.loss_class(labels, pred_class_me, sample_weight=sample_weights)
        loss_class_fe = self.loss_class(labels, pred_class_fe, sample_weight=sample_weights)
        loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)

        total_loss = (self.loss_class_me_weight * loss_class_me) \
                + (self.loss_class_fe_weight * loss_class_fe) \
                - (self.loss_gen_weight * loss_adv) \
                + self._random_effects_loss()

        self.metric_class_me.update_state(labels, pred_class_me)
        self.metric_class_fe.update_state(labels, pred_class_fe)
        self.metric_adv.update_state(clusters, pred_cluster)

        self.loss_class_me_tracker.update_state(loss_class_me)
        self.loss_class_fe_tracker.update_state(loss_class_fe)
        self.loss_adv_tracker.update_state(loss_adv)
        self.loss_total_tracker.update_state(total_loss)

        return {m.name: m.result() for m in self.metrics}
        
        
class MixedEffectsMLPNonlinearSlope(MixedEffectsMLP):
    def __init__(self, n_features: int, n_clusters: int,
                 adversary_layer_units: list=None,
                 slope_posterior_init_scale: float=0.1,
                 intercept_posterior_init_scale: float=0.1,
                 slope_prior_scale: float=0.1,
                 intercept_prior_scale: float=0.1,
                 kl_weight: float=0.001,
                 name: str='me_mlp',
                 **kwargs):
        """Mixed effects MLP classifier. Includes an adversarial classifier to
        disentangle the predictive features from the cluster-specific features.
        The cluster-specific features are then learned by a random effects layer.

        This architecture includes nonlinear random slopes (to be multiplied by the
        penultimate layer output of the fixed effects submodel) and random intercept.
        The model output is
        sigmoid((fixed effect logit) + (random slopes) * (penultimate FE layer output)
                + (random intercept))

        Args:
            n_features (int): Number of features. Unused; accepted for interface
                compatibility with MixedEffectsMLP.
            n_clusters (int): Number of clusters.
            adversary_layer_units (list, optional): Neurons in each layer of the
                adversary. Defaults to None, which uses [8, 8, 4].
            slope_posterior_init_scale (float, optional): Scale for initializing slope
                posterior means with a random normal distribution. Defaults to 0.1.
            intercept_posterior_init_scale (float, optional): Scale for initializing intercept
                posterior means with a random normal distribution. Defaults to 0.1.
            slope_prior_scale (float, optional): Scale of slope prior distribution. Defaults to 0.1.
            intercept_prior_scale (float, optional): Scale of intercept prior distribution.
                Defaults to 0.1.
            kl_weight (float, optional): KL divergence loss weight. Defaults to 0.001.
            name (str, optional): Model name. Defaults to 'me_mlp'.
        """
        del n_features  # unused: slopes act on the FE penultimate layer, not the raw input

        # Resolved here instead of in the signature to avoid a shared mutable default.
        if adversary_layer_units is None:
            adversary_layer_units = [8, 8, 4]

        # Deliberately skips MixedEffectsMLP.__init__ (which would build slopes
        # over the raw input features) and runs DomainAdversarialMLP.__init__.
        super(MixedEffectsMLP, self).__init__(n_clusters=n_clusters,
                                              adversary_layer_units=adversary_layer_units,
                                              name=name, **kwargs)
        self.classifier = MLPActivations(last_activation='linear', name='mlp')

        self.randomeffects = RandomEffectsLinearSlopeIntercept(
                        slopes=4,  # width of MLPActivations' penultimate (dense2) layer
                        slope_posterior_init_scale=slope_posterior_init_scale,
                        intercept_posterior_init_scale=intercept_posterior_init_scale,
                        slope_prior_scale=slope_prior_scale,
                        intercept_prior_scale=intercept_prior_scale,
                        kl_weight=kl_weight)

    def call(self, inputs, training=None):
        """Return ``(pred_class_me, pred_class_fe, pred_cluster)`` for inputs ``(x, z)``."""
        x, z = inputs
        fe_outs = self.classifier(x)
        fe_logit = fe_outs[-1]
        pred_class_fe = tf.nn.sigmoid(fe_logit)

        # Penultimate FE layer output
        fe_latents = fe_outs[-2]

        re_prod, re_int = self.randomeffects((fe_latents, z), training=training)
        # Combine in logit space, as MixedEffectsMLP does. Adding the FE
        # *probability* here (as before) applied a second sigmoid and, with zero
        # random effects, confined the prediction to [0.5, 0.73].
        pred_class_me = tf.nn.sigmoid(re_prod + re_int + fe_logit)

        fe_activations = tf.concat(fe_outs[:3], axis=1)
        pred_cluster = self.adversary(fe_activations)

        return pred_class_me, pred_class_fe, pred_cluster