import tensorflow as tf from tensorflow import keras from tensorflow.keras.layers import Dense, Activation,BatchNormalization import tensorflow.keras.layers as tkl from scMEDAL.models.random_effects import ClusterScaleBiasBlock from tensorflow.keras import layers from collections.abc import Iterable """ Author: Aixa Andrade with collaboration of Son Nguyen. Code inspired in the original ARMED's convolutional autoencoder code written by Kevin Nguyen for the melanoma experiment. Some snippets of code are borrowed as they are from Kevin Nguyen et al 2023 (ARMED paper (2023)). This code uses custom Dense layers for building custom scMEDAL vector autoencoders. """ class TiedDenseTranspose(tf.keras.layers.Layer): """ A tied dense transpose layer that shares weights with a source dense layer. Attributes: source_layer (tf.keras.layers.Dense): Source dense layer to tie weights with. activation (tf.keras.activations): Activation function for the layer. units (int): Number of units for the layer. kernel (tf.Variable): Shared weights with the source layer. bias_t (tf.Variable): Bias for the layer. """ # Inspired in Medium article of Building an autoencoder with tied weights in keras by Laurence Mayrand-Provencher (2019) # https://medium.com/@lmayrandprovencher/building-an-autoencoder-with-tied-weights-in-keras-c4a559c529a2 def __init__(self, source_layer: tf.keras.layers.Dense, units, activation=None, **kwargs): """ Initialize the TiedDenseTranspose layer. Args: source_layer (tf.keras.layers.Dense): Source dense layer to tie weights with. units (int): Number of units for the layer. activation (str, optional): Activation function to use. Defaults to None. **kwargs: Additional keyword arguments. """ self.source_layer = source_layer self.activation = tf.keras.activations.get(activation) self.units = units super().__init__(**kwargs) def build(self, batch_input_shape): """Build the layer weights.""" # it only shares weights but not biases self.kernel = self.source_layer.kernel # initializes bias as zeros self.bias_t = self.add_weight(name='bias_t', shape=(self.units,), initializer="zeros") super().build(batch_input_shape) def call(self, inputs): """Apply the layer operations on the input tensor.""" return self.activation(tf.matmul(inputs, self.kernel, transpose_b=True) + self.bias_t) class Encoder(tf.keras.Model): """ Encoder Layer for Neural Networks with optional batch normalization. Attributes: n_latent_dim (int): Number of latent dimensions. layer_units (list): List of units for each dense layer. return_layer_activations (bool): Flag to determine if layer activations should be returned. return_encoder_layers (bool): Flag to determine if encoder layers should be returned. layers (dict): Dictionary containing all the layers. dense_blocks (dict): Dictionary containing dense layers and batch normalization layers. """ def __init__(self, n_latent_dims: int=2, layer_units: list=[9, 7, 5], return_layer_activations: bool=False, return_encoder_layers: bool=False, use_batch_norm: bool=False, # Flag to determine if batch normalization should be used name='encoder', **kwargs): super(Encoder, self).__init__(name=name, **kwargs) self.n_latent_dim = n_latent_dims self.layer_units = layer_units self.return_layer_activations = return_layer_activations self.return_encoder_layers = return_encoder_layers self.use_batch_norm = use_batch_norm # Create dictionaries for the blocks and layers self.dense_blocks = {} self.all_layers = {} # Fill the dictionaries using a for loop for i, n_units in enumerate(self.layer_units): key_name = "dense_" + str(i) dense_layer = Dense(units=n_units, activation=None, name=key_name) # activation is None if batch norm is used self.dense_blocks[key_name] = dense_layer self.all_layers[key_name] = [dense_layer] if self.use_batch_norm: bn_key_name = "batch_norm_" + str(i) bn_layer = BatchNormalization(name=bn_key_name) self.all_layers[key_name].append(bn_layer) # Add activation layer separately if using batch norm activation_layer = tf.keras.layers.Activation('selu') self.all_layers[key_name].append(activation_layer) # Define the latent layer self.dense_latent = Dense(units=self.n_latent_dim, activation="selu", name="dense_latent") self.all_layers["dense_latent"] = [self.dense_latent] def call(self, inputs, training=None): x = inputs layer_activations = [] # Iterate through the layers using the dictionary for key, layers in self.all_layers.items(): for layer in layers: x = layer(x, training=training) # ensure to pass training parameter for batch normalization layer_activations.append(x) if self.return_layer_activations: return layer_activations elif self.return_encoder_layers: return self.all_layers, x else: return x class Decoder(tf.keras.Model): """ Decoder Layer for Neural Networks. The Decoder layers can be Tied with the Encoder layers if encoder layers are provided and if the tied_weights = True. Attributes: encoder_dense_layers (list): List of encoder dense layers to tie weights with. in_shape (tuple): Input shape for the autoencoder (encoder input shape). Input shape encoder = output shape decoder layer_units (list): List of units for each dense layer. last_activation (str): Last activation function for the decoder. layers (dict): Dictionary containing all the layers. """ def __init__(self, in_shape: tuple, encoder_layers: list = [], layer_units: list=[9,7,5], last_activation: str='sigmoid', name='decoder', tied_weights = True, **kwargs): """ Initialize the Decoder. Args: encoder_layers (list, optional): List of encoder layers to tie weights with. If you want a TiedDecoder, you have to provide the encoder_layers. Defaults to empty list. in_shape (tuple): Input shape for the autoencoder (encoder input shape). Input shape encoder = output shape decoder layer_units (list, optional): List containing the number of units for each dense layer. Defaults to [784, 392]. last_activation (str, optional): Last activation function for the decoder. Defaults to "sigmoid". name (str, optional): Name of the layer. Defaults to 'decoder'. tied_weights (bool, optional): If True, the layers of the Decoder are Tied with the Encoder. Else: The layers of the Decoder are Dense. Defaults to True. **kwargs: Additional keyword arguments. """ super(Decoder, self).__init__(name=name, **kwargs) self.in_shape = in_shape self.layer_units = layer_units self.last_activation = last_activation self.all_layers = {} self.tied_weights = tied_weights if (self.tied_weights == True )& (len(encoder_layers)>0): #If tied weights = True --> decoder layers are tied with the encoder layers #print("encoder layers",encoder_layers) # get encoder dense layers # encoder_dense_layers = [layer for layer in encoder_layers if "dense" in layer.name] # def is_iterable(obj): # """ Check if the object is iterable but not a string """ # return isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)) def is_iterable(obj): return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) # Using a nested list comprehension to handle both nested and flat list scenarios encoder_dense_layers = [layer for item in encoder_layers for layer in (item if is_iterable(item) else [item]) if "dense" in layer.name] self.encoder_dense_layers = encoder_dense_layers # build the decoder reverse looping through the encoder layers for n_units, e_layer in zip(self.layer_units[::-1], self.encoder_dense_layers[1:][::-1]): key_name = e_layer.name + "_t" self.all_layers[key_name] = TiedDenseTranspose(source_layer=e_layer, units=n_units, activation="selu", name=key_name) # out decoder: out layer shares weights with encoder first layer # defining layer with last activation # the last activation is sigmoid to make sure the values are between zero and one key_name = "dense_out" self.all_layers[key_name] = TiedDenseTranspose(source_layer=self.encoder_dense_layers[0], units=self.in_shape[-1], activation=self.last_activation, name=key_name) else: #If tied weights = False --> decoder layers are Dense layers # build the decoder reverse looping through the layer units for i,n_units in enumerate(self.layer_units[::-1]): key_name = "dense_"+str(len(self.layer_units)-i) self.all_layers[key_name] = Dense(units=n_units, activation="selu", name=key_name) # the last activation is sigmoid to make sure the values are between zero and one key_name = "dense_out" self.all_layers[key_name] = Dense(units=self.in_shape[-1], activation=self.last_activation, name=key_name) def call(self, inputs, training=None): """ Call the decoder layer with input data. Args: inputs (tf.Tensor): Input tensor data. training (bool, optional): If in training mode or not. Defaults to None. Returns: tf.Tensor: Processed output tensor. """ x = inputs # apply transposed dense layers (decoder) for key, layer in self.all_layers.items(): #print(layer.name) #x = layer(x) x = layer(x, training=training) return x class AE(tf.keras.Model): """ Autoencoder (AE) Model with tied weights. Attributes: in_shape (tuple): Input shape for the AE. layer_units (list): List of units for each dense layer in the encoder. n_latent_dims (int): Number of latent dimensions for the encoder. last_activation (str): Last activation function for the decoder. return_layer_activations (bool): Whether to return layer activations from the encoder. encoder (Encoder): Encoder part of the AE. decoder (Decoder): Decoder part of the AE, it has Tied weights with the Encoder. """ def __init__(self, in_shape: tuple, n_latent_dims: int = 2, layer_units: list = [9,7,5], last_activation: str = "sigmoid", return_layer_activations: bool = False, use_batch_norm: bool=False, name='ae', **kwargs): """ Initialize the AE model. Args: in_shape (tuple): Input shape for the AE. n_latent_dims (int, optional): Number of latent dimensions for the encoder. Defaults to 64. layer_units (list, optional): List containing the number of units for each dense layer in the encoder. Defaults to [784, 392]. last_activation (str, optional): Last activation function for the decoder. Defaults to "sigmoid". return_layer_activations (bool, optional): Whether to return layer activations from the encoder. Defaults to False. name (str, optional): Name of the model. Defaults to 'ae'. **kwargs: Additional keyword arguments. """ super(AE, self).__init__(name=name, **kwargs) self.in_shape = in_shape self.layer_units = layer_units self.n_latent_dims = n_latent_dims self.last_activation = last_activation self.return_layer_activations = return_layer_activations self.use_batch_norm = use_batch_norm self.encoder = Encoder(n_latent_dims=n_latent_dims, layer_units=layer_units, return_layer_activations=self.return_layer_activations, use_batch_norm=self.use_batch_norm) # Assuming the Encoder class returns a dictionary for its layers attribute encoder_layers_list = list(self.encoder.all_layers.values()) self.decoder = Decoder(in_shape=self.in_shape, encoder_layers=encoder_layers_list, layer_units=self.layer_units, last_activation=self.last_activation) def call(self, inputs, training=None): """ Call the AE model with input data. Args: inputs (tf.Tensor): Input tensor data. training (bool, optional): If in training mode or not. Defaults to None. Returns: tf.Tensor: Processed output tensor. """ # Get the encoder output. If return_layer_activations is True, # the encoder returns all layer activations, else just the latent representation. encoder_output = self.encoder(inputs, training=training) # Determine the latent representation based on return_layer_activations flag latent = encoder_output[-1] if self.return_layer_activations else encoder_output out = self.decoder(latent, training=training) return out class AEC(tf.keras.Model): """ An autoencoder-based classifier model built using TensorFlow's Keras API. This model is a combination of an autoencoder for unsupervised learning and a classifier for supervised learning. It is designed to work with input data in the specified input shape, compress it into a latent space using an encoder, and then reconstruct the input from this compressed representation using a decoder. Additionally, it uses the latent representation for classification purposes. The autoencoder part of the model is a 'tied' autoencoder, meaning that the weights of the encoder are tied to the weights of the decoder. This type of architecture can be beneficial for certain types of data compression and reconstruction tasks. Parameters: in_shape (tuple): The shape of the input data. n_latent_dims (int, optional): The number of dimensions for the latent space representation. Default is 2. layer_units (list, optional): The number of units in each layer of the encoder (and by extension, the decoder). Default is [9, 7, 5]. last_activation (str, optional): The activation function to be used in the last layer of the autoencoder. Default is 'sigmoid'. return_layer_activations (bool, optional): Flag to determine whether the encoder should return all layer activations or just the final latent representation. Default is False. n_pred (int, optional): The number of prediction classes for the classifier. Default is 20. layer_units_latent_classifier (list, optional): The number of units in each layer of the classifier. Default is [2]. name (str, optional): Name of the model. Default is 'ae_class'. The model has three main components: - An encoder that reduces the input to a lower-dimensional latent space. - A decoder that reconstructs the input from the latent space. - A classifier that uses the latent space representation for classification tasks. The `call` method of the model takes in input data and optionally a training flag and returns a dictionary with two keys: 'reconstruction_output' for the output of the autoencoder, and 'classification_output' for the output of the classifier. Example: model = AE_classifier(in_shape=(28, 28, 1)) # For training or inference output = model(data) """ def __init__(self, in_shape: tuple, n_latent_dims: int = 2, layer_units: list = [9,7,5], last_activation: str = "sigmoid", return_layer_activations = False, n_pred=20, layer_units_latent_classifier=[2], use_batch_norm: bool=False, name='aec', **kwargs): super(AEC, self).__init__(name=name, **kwargs) self.in_shape = in_shape self.layer_units = layer_units self.n_latent_dims = n_latent_dims self.last_activation = last_activation self.return_layer_activations = return_layer_activations self.n_pred = n_pred self.layer_units_latent_classifier = layer_units_latent_classifier self.use_batch_norm = use_batch_norm self.encoder = Encoder(n_latent_dims=n_latent_dims, layer_units=layer_units, return_layer_activations=self.return_layer_activations, use_batch_norm=self.use_batch_norm) # Assuming the Encoder class returns a dictionary for its layers attribute encoder_layers_list = list(self.encoder.all_layers.values()) # Tied AE. ENCODER WEIGHTS = DECODER self.decoder = Decoder(in_shape=self.in_shape, encoder_layers=encoder_layers_list, layer_units=self.layer_units, last_activation=self.last_activation) self.classifier = Classifier(n_clusters=self.n_pred,layer_units = self.layer_units_latent_classifier) def call(self, inputs, training=None): # Get the encoder output. If return_layer_activations is True, # the encoder returns all layer activations, else just the latent representation. #print("model inputs shape",inputs.shape) encoder_output = self.encoder(inputs, training=training) # Determine the latent representation based on return_layer_activations flag latent = encoder_output[-1] if self.return_layer_activations else encoder_output # Pass the latent representation through the decoder and classifier recon = self.decoder(latent, training=training) classification = self.classifier(latent) #print("recon pred shape",recon.shape,"class pred shape",classification.shape) return {'reconstruction_output': recon, 'classification_output': classification} class AdversarialClassifier(tkl.Layer): def __init__(self, n_clusters: int, n_latent_dims: int=2, layer_units: list=[5, 4], name: str='adversary', **kwargs): """Adversarial classifier. Args: n_clusters (int): number of clusters (classes) layer_units (list, optional): Neurons in each layer. Can be a list of any length. Defaults to [8, 8, 8]. name (str, optional): Model name. Defaults to 'adversary'. """ super(AdversarialClassifier, self).__init__(name=name, **kwargs) self.n_clusters = n_clusters self.layer_units = layer_units self.all_layers = [] for iLayer, neurons in enumerate(layer_units): self.all_layers += [tkl.Dense(neurons, activation='relu', name=name + '_dense' + str(iLayer))] self.all_layers += [tkl.Dense(self.n_clusters , activation='softmax', name=name + '_dense_out')] def call(self, inputs): if type(inputs) is list: inputs = tf.concat(inputs, axis=-1) x = inputs for layer in self.all_layers: x = layer(x) return x def get_config(self): return {'n_clusters': self.n_clusters, 'layer_units': self.layer_units} def get_config(self): return {'n_clusters': self.n_clusters, 'layer_units': self.layer_units} class Classifier(tf.keras.layers.Layer): """ A custom Keras Layer for classification tasks. This layer implements a classifier with a user-defined number of dense layers followed by an output layer for clustering (can be used to predict donors/batches). Optionally, it can also implement another classifier subnet for a second prediction (can be used to predict celltypes). Attributes: ----------- layer_units : list List of integers specifying the number of units in each dense layer. layers_cluster : dict Dictionary containing dense layers for clustering. n_clusters : int Number of clusters for the classification task. n_pred : int Number of predictions. get_pred : bool Flag to determine if prediction subnet should be built. layers_pred : dict Dictionary containing dense layers for prediction, built only if get_pred is True. Methods: -------- call(inputs, training=None): Perform the forward pass for the clustering and optionally for prediction. get_config(): Returns a dictionary containing the configuration of the classifier (i.e., n_clusters). Parameters: ----------- n_clusters : int Number of clusters for the classification task. layer_units : list, optional List of integers specifying the number of units in each dense layer. Defaults to [32, 16]. n_pred : int, optional Number of predictions, only used if get_pred is True. Defaults to 4. get_pred : bool, optional Flag to determine if a subnet for predictions should be built. Defaults to False. name : str, optional Name of the layer. Defaults to 'latent_classifier'. **kwargs : Additional keyword arguments inherited from tf.keras.layers.Layer. """ def __init__(self, n_clusters: int, layer_units: list=[2], n_pred: int = 10, get_pred = False, name='latent_classifier', **kwargs): super(Classifier, self).__init__(name=name, **kwargs) self.layer_units = layer_units self.layers_cluster = {} self.n_clusters = n_clusters self.n_pred = n_pred self.get_pred = get_pred if self.get_pred: self.layers_pred = {} # Fill the dictionaries using a for loop for i, n_units in enumerate(self.layer_units): key_name = "dense_" + str(i) self.layers_cluster[key_name] = Dense(units=n_units, activation="relu", name=key_name) # if get_pred ==True, build dense subnet to get_predictions if self.get_pred: self.layers_pred[key_name] = Dense(units=n_units, activation="relu", name=key_name) #This layer predicts the number of clusters self.layers_cluster["dense_out"] = Dense(self.n_clusters, activation='softmax', name=name + '_out') if self.get_pred: #if get_pred ==True: use softmax to pred the classes self.layers_pred["dense_out"] = Dense(self.n_pred, activation='softmax', name=name + '_out') def call(self, inputs, training=None): c = inputs for key, layer in self.layers_cluster.items(): c = layer(c) # final c: vector of n samples * n clusters with the probability of each sample being of each cluster if self.get_pred: y = inputs for key, layer in self.layers_pred.items(): y = layer(y) #return class predictions, cluster predictions return y,c else: return c def get_config(self): return {'n_clusters': self.n_clusters} class DomainAdversarialAE(AE): """ scMEDAL Fixed Effects subnetwork (FE) An extension of the autoencoder (AE) that integrates an adversarial classifier in its architecture to perform unsupervised domain adaptation. This class enables training of the AE such that the latent representation is invariant to domain shifts, making the model robust against changes in the data distribution. Attributes: in_shape (tuple): Shape of the input data. n_clusters (int): Number of clusters for adversarial classification. n_latent_dims (int): Dimensionality of the latent space. layer_units (list): Number of units in each dense layer of the encoder/decoder. last_activation (str): Activation function for the decoder's output layer. get_pred (bool): Whether the model includes latent space predictions. use_batch_norm (bool): Whether batch normalization is applied to the dense layers. n_pred (int): Number of prediction classes for the latent classifier (if `get_pred` is True). layer_units_latent_classifier (list): Number of units in each layer of the latent classifier (if `get_pred` is True). latent_classifier (Classifier): Classifier for generating predictions from the latent space (if `get_pred` is True). encoder (Encoder): Encoder component of the autoencoder. decoder (Decoder): Decoder component of the autoencoder. adversary (AdversarialClassifier): Adversarial classifier for unsupervised domain adaptation. """ def __init__(self, in_shape: tuple, n_clusters: int, n_latent_dims: int=2, layer_units: list=[9,7,5], last_activation: str="sigmoid", n_pred: int=10, layer_units_latent_classifier: list=[2], get_pred=False, use_batch_norm: bool=False, name='da_ae', **kwargs): """ Initialize the Domain Adversarial Autoencoder. Args: in_shape (tuple): Shape of the input data. n_clusters (int): Number of clusters for adversarial classification. n_latent_dims (int, optional): Dimensionality of the latent space. Default is 2. layer_units (list, optional): Number of units in each dense layer of the encoder/decoder. Default is [9, 7, 5]. last_activation (str, optional): Activation function for the decoder's output layer. Default is "sigmoid". n_pred (int, optional): Number of prediction classes for the latent classifier. Default is 10. layer_units_latent_classifier (list, optional): Number of units in each layer of the latent classifier. Default is [2]. get_pred (bool, optional): Whether to include latent space predictions. Default is False. use_batch_norm (bool, optional): Whether to apply batch normalization. Default is False. name (str, optional): Name of the model instance. Default is 'da_ae'. **kwargs: Additional arguments for the base class. """ super(AE, self).__init__(name=name, **kwargs) self.in_shape = in_shape self.n_clusters = n_clusters self.n_latent_dims = n_latent_dims self.layer_units = layer_units self.last_activation = last_activation self.get_pred = get_pred self.use_batch_norm = use_batch_norm if self.get_pred: self.n_pred = n_pred self.layer_units_latent_classifier = layer_units_latent_classifier #The latent classifier returns class predictions self.latent_classifier = Classifier(n_clusters=self.n_pred,layer_units = self.layer_units_latent_classifier) #autoencoder: encoder +decoder self.encoder = Encoder(n_latent_dims = n_latent_dims, layer_units=self.layer_units, return_layer_activations=True, use_batch_norm=self.use_batch_norm) encoder_layers_list = list(self.encoder.all_layers.values()) self.decoder = Decoder(in_shape=self.in_shape,encoder_layers = encoder_layers_list,layer_units = self.layer_units, last_activation = self.last_activation) #adversarial classifier self.adversary = AdversarialClassifier(n_clusters = self.n_clusters, n_latent_dims = self.n_latent_dims, layer_units=self.layer_units) def call(self, inputs,training=None): """ Forward pass through the Domain Adversarial Autoencoder. Args: inputs (tuple): Tuple containing the input data and cluster information. Returns: tuple: Reconstruction from the decoder and prediction from the adversarial classifier. """ x, clusters = inputs #print(x.shape) # encoder encoder_activations = self.encoder(x,training=training) # apply adversary to encoder activations (decoder shares weights with encoder) pred_cluster = self.adversary(encoder_activations) # latent space is the last activation layer latent = encoder_activations[-1] # decoder is applied to latent recon = self.decoder(latent,training=training) if self.get_pred: # classification pred_class = self.latent_classifier(latent) return (recon, pred_class, pred_cluster) else: return (recon, pred_cluster) def compile(self, loss_recon=tf.keras.losses.MeanSquaredError(), loss_multiclass=tf.keras.losses.CategoricalCrossentropy(), metric_multiclass=tf.keras.metrics.CategoricalAccuracy(name='acc'), opt_autoencoder=tf.keras.optimizers.Adam(lr=0.0001), opt_adversary=tf.keras.optimizers.Adam(lr=0.0001), loss_recon_weight=1.0, loss_gen_weight=0.05, loss_class_weight=0.01): """ Compile the model with specified losses, metrics, and optimizers. Args: loss_recon (tf loss): Reconstruction loss function. loss_multiclass (tf loss): multiclass loss function. It works for all multiclass tasks. metric_multiclass (tf metric): Metric for adversarial classifier performance. opt_autoencoder (tf optimizer): Optimizer for autoencoder. opt_adversary (tf optimizer): Optimizer for adversarial classifier. loss_recon_weight (float): Weight for the reconstruction loss. loss_gen_weight (float): Weight for the adversarial loss (generator part). loss_class_weight (float): Weight for the class loss. Only used if get_pred ==True. """ super().compile() self.loss_recon = loss_recon # adv and class are the same loss but I decided to use diff names self.loss_adv = loss_multiclass self.loss_class = loss_multiclass self.opt_autoencoder = opt_autoencoder self.opt_adversary = opt_adversary # track mean loss self.loss_recon_tracker = tf.keras.metrics.Mean(name='recon_loss') self.loss_adv_tracker = tf.keras.metrics.Mean(name='adv_loss') self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss') # define metrics self.metric_adv = metric_multiclass self.metric_class = metric_multiclass # define loss weights self.loss_recon_weight = loss_recon_weight self.loss_gen_weight = loss_gen_weight if self.get_pred: # define latent class loss, metric and weights self.metric_multiclass = metric_multiclass self.loss_class_weight = loss_class_weight self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss') @property def metrics(self): if self.get_pred: return [self.loss_recon_tracker, self.loss_class_tracker, self.loss_adv_tracker, self.loss_total_tracker, self.metric_adv, self.metric_class] else: return [self.loss_recon_tracker, self.loss_adv_tracker, self.loss_total_tracker, self.metric_adv] def train_step(self, data): """ Perform a training step for the model. Args: data (tuple): Tuple containing input data, target data, and optionally sample weights. Returns: dict: Dictionary containing values of tracked metrics. """ #load data x, clusters = data[0] if self.get_pred: _, labels = data[1] # else: # labels = None sample_weights = None if len(data) != 3 else data[2] #CHECK IF THE SHAPES ARE CORRECT assert x.shape[0] == clusters.shape[0], "Mismatch between x and clusters" if self.get_pred: assert x.shape[0] == labels.shape[0], "Mismatch between x and labels" #Train adversary encoder_outs = self.encoder(x, training=True) #calculate adv loss with tf.GradientTape() as gt: pred_cluster = self.adversary(encoder_outs) loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights) #apply gradients grads_adv = gt.gradient(loss_adv, self.adversary.trainable_variables) # minimizing adv loss (remove comments) self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_variables)) # Update adversarial loss tracker self.metric_adv.update_state(clusters, pred_cluster) self.loss_adv_tracker.update_state(loss_adv) # Train autoencoder with tf.GradientTape(persistent=True) as gt2: #apply model outputs = self(inputs=(x, clusters), training=True) if self.get_pred: pred_recon, pred_class, pred_cluster = outputs #(+ pred class) else: pred_recon, pred_cluster = outputs #compute individual losses loss_recon = self.loss_recon(x, pred_recon, sample_weight=sample_weights) loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights) if self.get_pred: loss_class = self.loss_class(labels, pred_class, sample_weight=sample_weights) #add class loss to total loss: (recon) - adv loss (gen) +class loss total_loss = (self.loss_recon_weight * loss_recon) \ + (self.loss_class_weight * loss_class) \ - (self.loss_gen_weight * loss_adv) else: #compute total ae loss: (recon) - adv loss (gen) total_loss = (self.loss_recon_weight * loss_recon)- (self.loss_gen_weight * loss_adv) if self.get_pred: # +latent classifier trainable vars lsWeights = self.encoder.trainable_variables + self.decoder.trainable_variables \ + self.latent_classifier.trainable_variables else: lsWeights = self.encoder.trainable_variables + self.decoder.trainable_variables #backpropagate grads_aec = gt2.gradient(total_loss, lsWeights) self.opt_autoencoder.apply_gradients(zip(grads_aec, lsWeights)) # Update loss trackers if self.get_pred: self.metric_class.update_state(labels, pred_class) self.loss_class_tracker.update_state(loss_class) self.loss_recon_tracker.update_state(loss_recon) self.loss_total_tracker.update_state(total_loss) return {m.name: m.result() for m in self.metrics} def test_step(self, data): """ Perform a testing (validation) step for the model. Args: data (tuple): Tuple containing input data and target data. Returns: dict: Dictionary containing values of tracked metrics. """ x, clusters = data[0] if self.get_pred: _, labels = data[1] # else: # labels = None #CHECK IF THE SHAPES ARE CORRECT assert x.shape[0] == clusters.shape[0], "Mismatch between x and clusters" if self.get_pred: assert x.shape[0] == labels.shape[0], "Mismatch between x and labels" # apply model outputs = self(inputs=(x, clusters), training=False) if self.get_pred: pred_recon, pred_class, pred_cluster = outputs #(+ pred class) else: pred_recon, pred_cluster = outputs # compute ind losses loss_recon = self.loss_recon(x, pred_recon) loss_adv = self.loss_adv(clusters, pred_cluster) #compute total loss if self.get_pred: loss_class = self.loss_class(labels, pred_class) #add class loss to total loss: (recon) - adv loss (gen) +class loss total_loss = (self.loss_recon_weight * loss_recon) \ + (self.loss_class_weight * loss_class) \ - (self.loss_gen_weight * loss_adv) else: total_loss = (self.loss_recon_weight * loss_recon)- (self.loss_gen_weight * loss_adv) #update metrics and losses self.metric_adv.update_state(clusters, pred_cluster) self.loss_recon_tracker.update_state(loss_recon) self.loss_adv_tracker.update_state(loss_adv) self.loss_total_tracker.update_state(total_loss) if self.get_pred: self.metric_class.update_state(labels, pred_class) self.loss_class_tracker.update_state(loss_class) return {m.name: m.result() for m in self.metrics} class RandomEffectEncoder(Encoder): """ RandomEffectEncoder: A specialized encoder that incorporates random effects with dense layers. Inherits from the provided Encoder class. This encoder is designed to model random effects by introducing specialized layers for handling them. Each dense layer is followed by a random effect layer and an activation layer. Attributes: n_latent_dims (int): Number of latent dimensions. layer_units (list): List containing the number of units for each dense layer. post_loc_init_scale (float): Initial scale for the location of the posterior distribution. prior_scale (float): Scale for the prior distribution. kl_weight (float): Weighting factor for the KullbackĀ–Leibler divergence. re_layers (dict): Dictionary containing random effect layers. act_layers (dict): Dictionary containing activation layers. layer_blocks (dict): Dictionary containing blocks of (dense, random effect, activation) layers. Args: n_latent_dims (int, optional): Number of latent dimensions. Defaults to 2. layer_units (list, optional): List containing the number of units for each dense layer. Defaults to [8]. post_loc_init_scale (float, optional): Initial scale for the location of the posterior distribution. Defaults to 0.1. prior_scale (float, optional): Scale for the prior distribution. Defaults to 0.25. kl_weight (float, optional): Weighting factor for the KullbackĀ–Leibler divergence. Defaults to 1e-5. name (str, optional): Name of the encoder. Defaults to 'encoder'. **kwargs: Additional keyword arguments. """ def __init__(self, n_latent_dims: int=2, layer_units: list=[8], post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name = 'encoder', **kwargs): """ Initialize the RandomEffectEncoder. """ super(RandomEffectEncoder, self).__init__(n_latent_dims=n_latent_dims, layer_units=layer_units,name=name, **kwargs) #dictionary of random effect layers self.re_layers = {} #dictionary of activation layers self.act_layers = {} #Build blocks of (dense, RE, activation layers) self.layer_blocks = {} #dense blocks are inherited from Encoder class for key, layer in self.dense_blocks.items(): #layer i layer_i = key.split("_")[-1] #random effect layer self.re_layers["re_"+layer_i] = ClusterScaleBiasBlock(layer.units, post_loc_init_scale = post_loc_init_scale, prior_scale = prior_scale, kl_weight = kl_weight, name = name + '_re_'+layer_i) #act layer self.act_layers["act_"+layer_i] = Activation('selu') #add blocks of (dense, RE, activation layers) self.layer_blocks["block_"+layer_i] = (layer,self.re_layers["re_"+layer_i], self.act_layers["act_"+layer_i]) #define re_encoder_layers self.re_encoder_layers = {**self.layer_blocks, "dense_latent": self.dense_latent} def call(self, inputs, training=None): """ Forward pass for the RandomEffectEncoder. Args: inputs (tuple): A tuple containing two elements - the input data (x) and the random effects data (z). training (bool, optional): If in training mode or not. Defaults to None. Returns: tf.Tensor: Transformed input after passing through dense, random effect and activation layers. """ x, z = inputs # print("x.shape:",x.shape) # print("z.shape:",z.shape) for key, (dense, re, activation) in self.layer_blocks.items(): x = dense(x) x = re((x, z), training=training) x = activation(x) x = self.dense_latent(x) return x # def summary(self): # print("RandomEffectEncoder Summary:") # print(f"{'Layer':<20} {'Output Shape':<20} {'# Params':<10}") # for name, (dense, re, activation) in self.layer_blocks.items(): # # Assuming model has been built at least once so these methods can be accessed # print(f"{dense.name:<20} {str(dense.output_shape):<20} {dense.count_params():<10}") # print(f"{re.name:<20} {'-':<20} {re.count_params():<10}") # print(f"{activation.name:<20} {'-':<20} {'-':<10}") # print(f"{self.dense_latent.name:<20} {str(self.dense_latent.output_shape):<20} {self.dense_latent.count_params():<10}") class RandomEffectDecoder(Decoder): def __init__(self, in_shape: tuple, layer_units: list=[8], last_activation: str='sigmoid', post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name = 'decoder', **kwargs): """ Initialize the RandomEffectDecoder. """ #I do not want tied weights in the RandomEffectDecoder super(RandomEffectDecoder, self).__init__(in_shape = in_shape, layer_units = layer_units, last_activation = last_activation, name = name, tied_weights = False, **kwargs) #dictionary of random effect layers self.re_layers = {} #dictionary of activation layers self.act_layers = {} #Build blocks of (dense, RE, activation layers) self.layer_blocks = {} #dense blocks are inherited from Encoder class for key, layer in self.all_layers.items(): #layer i layer_i = key.split("_")[-1] #random effect layer self.re_layers["re_"+layer_i] = ClusterScaleBiasBlock(layer.units, post_loc_init_scale = post_loc_init_scale, prior_scale = prior_scale, kl_weight = kl_weight, name = name + '_re_'+layer_i) if key == 'dense_out': #for the block that has the dense_out layer, the activation layer = last_activation self.act_layers["last_act"] = Activation(self.last_activation, name = name + '_act_'+self.last_activation) #add blocks of (dense, RE, activation layers) self.layer_blocks["block_"+layer_i] = (layer,self.re_layers["re_"+layer_i], self.act_layers["last_act"]) else: #all other activation layers are 'relu' self.act_layers["act_"+layer_i] = Activation('selu', name = name + '_act_'+layer_i) #add blocks of (dense, RE, activation layers) self.layer_blocks["block_"+layer_i] = (layer,self.re_layers["re_"+layer_i], self.act_layers["act_"+layer_i]) self.re_decoder_layers = self.layer_blocks def call(self, inputs, training=None): """ Forward pass for the RandomEffectDecoder. Args: inputs (tuple): A tuple containing two elements - the input data (x) and the random effects data (z). training (bool, optional): If in training mode or not. Defaults to None. Returns: tf.Tensor: Transformed input after passing through dense, random effect and activation layers. """ x, z = inputs for key, (dense, re, activation) in self.layer_blocks.items(): x = dense(x) x = re((x, z), training=training) x = activation(x) return x # def summary(self): # print("RandomEffectDecoder Summary:") # print(f"{'Layer':<20} {'Output Shape':<20} {'# Params':<10}") # for name, (dense, re, activation) in self.layer_blocks.items(): # # Assuming model has been built at least once so these methods can be accessed # print(f"{dense.name:<20} {str(dense.output_shape):<20} {dense.count_params():<10}") # print(f"{re.name:<20} {'-':<20} {re.count_params():<10}") # print(f"{activation.name:<20} {'-':<20} {'-':<10}") class DomainEnhancingAutoencoderClassifier(tf.keras.Model): """ scMEDAL Random Effects subnetwork (RE) Autoencoder model for classification and clustering of the batch effects. This model leverages an autoencoder structure with a domain-enhanced approach to perform classification and clustering tasks. It comprises an encoder (`RandomEffectEncoder`), a decoder (`RandomEffectDecoder`), and a batch classifier which operates in the latent space. The model can predict clusters or class labels based on the latent and reconstructed representations. Parameters: ------------ - in_shape (tuple): Input shape of the data. - n_clusters (int, optional): Number of clusters for classification. Default is 10. - n_latent_dims (int, optional): Dimensionality of the latent space. Default is 2. - layer_units (list, optional): Units for each layer in the autoencoder. Default is [10, 5]. - layer_units_classifier (list, optional): Units for each layer in the classifier. Default is [2]. - n_pred (int, optional): Number of prediction classes if `get_pred` is True. Default is 10. - last_activation (str, optional): Activation for the last layer of the autoencoder. Default is "sigmoid". - post_loc_init_scale (float, optional): Initial scale for the posterior's location. Default is 0.1. - prior_scale (float, optional): Scale for the prior distribution. Default is 0.25. - kl_weight (float, optional): Weight for KL divergence loss. Default is 1e-5. - get_pred (bool, optional): Predict class labels alongside clusters. Default is False. - get_recon_cluster (bool, optional): Retrieve cluster prediction from reconstruction. Default is False. - name (str, optional): Model's name. Default is "ae". Attributes: ------------ Various components of the model such as the encoder, decoder, and classifiers are stored as attributes. Methods: ------------ - call(inputs, training=None): Performs a forward pass of the model. - compile(...): Configures the model for training. - train_step(data): Defines a single training step for the model. - test_step(data): Defines a single test (or validation) step for the model. Note: The model is designed to handle input data as a tuple of (count matrix, clusters). If enabled (via `get_pred`), it can also take labels for supervised training. Outputs include the reconstructed data and the predictions based on latent and reconstructed representations. """ def __init__(self, in_shape: tuple, n_clusters: int=10, n_latent_dims: int = 2, layer_units: list = [10,5], layer_units_classifier:list = [2], n_pred: int = 10, last_activation: str = "sigmoid", post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, get_pred = False, get_recon_cluster = False, name='ae', **kwargs): super(DomainEnhancingAutoencoderClassifier, self).__init__(name=name, **kwargs) self.in_shape = in_shape self.n_clusters = n_clusters self.n_latent_dims = n_latent_dims self.layer_units = layer_units self.last_activation = last_activation self.get_pred = get_pred self.n_pred = n_pred self.layer_units_classifier = layer_units_classifier self.get_recon_cluster = get_recon_cluster # RE encoder self.re_encoder = RandomEffectEncoder(n_latent_dims=self.n_latent_dims, layer_units=self.layer_units, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight) # RE decoder: weights not tied self.re_decoder = RandomEffectDecoder(in_shape=self.in_shape, layer_units = self.layer_units, last_activation = self.last_activation, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight) # The latent classifier returns class predictions in addition to cluster predictions if get_pred =True self.re_latent_classifier = Classifier(n_clusters=self.n_clusters,layer_units = self.layer_units_classifier,n_pred = self.n_pred, get_pred = self.get_pred) # get cluster prediction from reconstruction if self.get_recon_cluster: self.re_recon_classifier = Classifier(n_clusters=self.n_clusters,layer_units = self.layer_units_classifier, get_pred = False) def call(self, inputs, training=None): if len(inputs) != 2: raise ValueError('Model inputs need to be a tuple of (count matrix, clusters)') x, z = inputs # Encode inputs latent = self.re_encoder((x, z), training=training) # Reconstruct image from latents recon = self.re_decoder((latent, z), training=training) output_dict = {'recon': recon} # Apply latent classifier latent_outs = self.re_latent_classifier(latent) if self.get_pred: # The latent classifier returns class predictions in addition to cluster predictions if get_pred=True pred_y, pred_c_latent = latent_outs output_dict['pred_y'] = pred_y output_dict['pred_c_latent'] = pred_c_latent else: pred_c_latent = latent_outs output_dict['pred_c_latent'] = pred_c_latent if self.get_recon_cluster: # Cluster predictions from reconstructed counts pred_c_recon = self.re_recon_classifier(recon) output_dict['pred_c_recon'] = pred_c_recon return output_dict def compile(self, loss_recon=tf.keras.losses.MeanSquaredError(), loss_multiclass=tf.keras.losses.CategoricalCrossentropy(), metric_multiclass=tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy'), optimizer=tf.keras.optimizers.Adam(lr=0.0001), loss_recon_weight=1.0, loss_class_weight=0.01, loss_latent_cluster_weight=0.001, loss_recon_cluster_weight=0.001): super().compile() self.loss_recon = loss_recon # the loss multiclass will be used for multiclass classification (cluster, class pred, etc) self.loss_multiclass = loss_multiclass self.optimizer = optimizer # loss weights self.loss_latent_cluster_weight = loss_latent_cluster_weight self.loss_recon_weight = loss_recon_weight # Loss trackers (mean loss across all the batches) self.loss_recon_tracker = tf.keras.metrics.Mean(name='recon_loss') self.loss_latent_cluster_tracker = tf.keras.metrics.Mean(name='la_clus_loss') self.loss_kl_tracker = tf.keras.metrics.Mean(name='kld') self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss') if self.get_pred: self.metric_multiclass = metric_multiclass self.loss_class_weight = loss_class_weight self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss') if self.get_recon_cluster: self.loss_recon_cluster_weight = loss_recon_cluster_weight self.loss_recon_cluster_tracker = tf.keras.metrics.Mean(name='recon_clus_loss') @property def metrics(self): metrics_list = [self.loss_recon_tracker, self.loss_latent_cluster_tracker, self.loss_kl_tracker, self.loss_total_tracker] if self.get_pred: metrics_list = metrics_list +[self.loss_class_tracker, self.metric_multiclass] elif self.get_recon_cluster: metrics_list = metrics_list +[self.loss_recon_cluster_tracker] return metrics_list def _compute_update_loss(self, loss_recon, loss_latent_cluster, loss_recon_cluster=None,loss_class = None, training=True): '''Compute total loss and update loss running means''' #update loss if (self.get_pred)&(loss_class is not None): self.loss_class_tracker.update_state(loss_class) if (self.get_recon_cluster)&(loss_recon_cluster is not None): self.loss_recon_cluster_tracker.update_state(loss_recon_cluster) self.loss_recon_tracker.update_state(loss_recon) self.loss_latent_cluster_tracker.update_state(loss_latent_cluster) if training: # The encoder and decoder have RandomEffect Layers, which inherit the properties of tpl.DenseVariational. # This layer adds the kld as regularization loss to the model. The regularizations are stored in model.losses. # Since there are more than one RElayers, we get the mean of all of them. kld = tf.reduce_mean(self.re_encoder.losses) + tf.reduce_mean(self.re_decoder.losses) self.loss_kl_tracker.update_state(kld) else: # KLD can't be computed at inference time because posteriors are simplified to # point estimates kld = 0 loss_total = (self.loss_recon_weight*loss_recon) + (self.loss_latent_cluster_weight * loss_latent_cluster)+kld if (self.get_pred)&(loss_class is not None): loss_total = loss_total + (self.loss_class_weight * loss_class) if (self.get_recon_cluster)&(loss_recon_cluster is not None): loss_total = loss_total + (self.loss_recon_cluster_weight * loss_recon_cluster) self.loss_total_tracker.update_state(loss_total) return loss_total def train_step(self, data): #missing to edit this part #load data x, clusters = data[0] if self.get_pred: _, labels = data[1] sample_weights = None if len(data) != 3 else data[2] # Train the rest of the model with tf.GradientTape() as gt: # Apply RE autoencoder: encoder + decoder outputs = self((x, clusters), training=True) recon = outputs['recon'] pred_c_latent = outputs['pred_c_latent'] if self.get_pred: pred_y = outputs['pred_y'] # Multiclass loss loss_class = self.loss_multiclass(labels, pred_y) else: loss_class = None if self.get_recon_cluster: pred_c_recon_1 = outputs['pred_c_recon'] loss_recon_cluster_1 = self.loss_multiclass(clusters, pred_c_recon_1) else: loss_recon_cluster_1 = None # mse loss loss_recon = self.loss_recon(x, recon) loss_latent_cluster = self.loss_multiclass(clusters, pred_c_latent) loss_total = self._compute_update_loss(loss_recon = loss_recon, loss_latent_cluster = loss_latent_cluster, loss_recon_cluster = loss_recon_cluster_1, loss_class = loss_class) # get trainable variables lsWeights = self.re_encoder.trainable_variables + self.re_decoder.trainable_variables # if the weight of loss_latent_cluster_weight>0, add it to the trainable variables if self.loss_latent_cluster_weight>0: lsWeights = lsWeights + self.re_latent_classifier.trainable_variables if self.get_recon_cluster: lsWeights = lsWeights + self.re_recon_classifier.trainable_variables # backpropagate grads = gt.gradient(loss_total, lsWeights) self.optimizer.apply_gradients(zip(grads, lsWeights)) if self.get_pred: # Update metrics self.metric_multiclass.update_state(labels, pred_y) return {m.name: m.result() for m in self.metrics} def test_step(self, data): #load data x, clusters = data[0] if self.get_pred: _, labels = data[1] sample_weights = None if len(data) != 3 else data[2] outputs = self((x, clusters), training=False) recon = outputs['recon'] pred_c_latent = outputs['pred_c_latent'] if self.get_pred: pred_y = outputs['pred_y'] # Multiclass loss loss_class = self.loss_multiclass(labels, pred_y) else: loss_class = None if self.get_recon_cluster: pred_c_recon_1 = outputs['pred_c_recon'] loss_recon_cluster_1 = self.loss_multiclass(clusters, pred_c_recon_1) else: loss_recon_cluster_1 = None loss_recon = self.loss_recon(x, recon) loss_latent_cluster = self.loss_multiclass(clusters, pred_c_latent) loss_total = self._compute_update_loss(loss_recon = loss_recon, loss_latent_cluster = loss_latent_cluster, loss_recon_cluster = loss_recon_cluster_1, loss_class = loss_class, training=False) if self.get_pred: # Update metrics self.metric_multiclass.update_state(labels, pred_y) return {m.name: m.result() for m in self.metrics} fe_latent, re_latent,z = inputs x = self.concat2subnets([fe_latent, re_latent]) # apply hidden layers for key, layer in self.dense_hidden_layers.items(): x = layer(x) # I will take me_latent after apply a dense layer to me_latent = fe_latent +re_latent. However, this layer may be optional me_latent = self.dense_me_latent(x) # me_pred_y is always done me_outputs = self.me_classifier((me_latent,z)) return me_outputs class MixedEffectsEncoder(tf.keras.layers.Layer): """ A TensorFlow Keras layer that concatenates fixed effects (FE) and random effects (RE) latent spaces, and optionally applies a random effects (RE) layer. The resulting mixed effects latent space is then processed through a series of dense hidden layers. Parameters: - n_latent_dims (int): The number of dimensions in the mixed effects latent space. - layer_units (list of int): The number of units in each dense hidden layer. - post_loc_init_scale (float): Initial scale for the location in the post-RE layer, used if an RE layer is added. - prior_scale (float): Scale of the prior in the RE layer, used if an RE layer is added. - kl_weight (float): Weight of the KL divergence in the loss, used if an RE layer is added. - add_re_2_meclass (bool): Determines whether to add an RE layer to the Mixed Effects Classifier. - name (str): Name of the layer. - **kwargs: Additional keyword arguments for the base Layer class. This encoder first concatenates the FE and RE latent spaces. It then processes the concatenated latent space through a series of dense hidden layers defined in `layer_units`. If `add_re_2_meclass` is True, an RE layer is applied after the dense hidden layers. The output is a mixed effects latent space that can be used for further processing or classification. The `call` method: Takes inputs `fe_latent`, `re_latent`, and `z`, and processes them through the encoder to produce the mixed effects latent space. If `add_re_2_meclass` is True, `z` is used in the RE layer. Inputs: - fe_latent: The latent representation of the fixed effects. - re_latent: The latent representation of the random effects. - z: Additional features or information, used if an RE layer is added. Returns: - me_latent: The resulting mixed effects latent space. """ def __init__(self, n_latent_dims: int = 2, layer_units: list = [10,5], post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, add_re_2_meclass = False, name='me_encoder', **kwargs): super(MixedEffectsEncoder, self).__init__(name=name, **kwargs) self.n_latent_dims = n_latent_dims self.layer_units = layer_units #add RE layer to ME classifier self.add_re_2_meclass = add_re_2_meclass self.concat2subnets = tf.keras.layers.Concatenate(axis=-1, name=name + 'concat_fe_re_latent') # define hidden layers self.dense_hidden_layers = {} for i, n_units in enumerate(self.layer_units): key_name = "dense_" + str(i) self.dense_hidden_layers[key_name] = Dense(units=n_units, activation="selu", name=key_name) if self.add_re_2_meclass: self.re_layer = ClusterScaleBiasBlock(self.n_latent_dims, post_loc_init_scale = post_loc_init_scale, prior_scale = prior_scale, kl_weight = kl_weight, name = name + '_re_layer') self.act = Activation('selu') self.dense_me_latent = Dense(units=self.n_latent_dims, activation="selu", name="dense_me_latent") def call(self, inputs, training=None): # fe_latent, re_latent,z = inputs # fe_latent = inputs["fe_latent"] # re_latent = inputs["re_latent"] fe_latent = inputs.get("fe_latent") re_latent = inputs.get("re_latent", None) # Only concatenate re_latent if it is not None if re_latent is not None: x = self.concat2subnets([fe_latent, re_latent]) else: x = fe_latent if self.add_re_2_meclass: z = inputs["z"] # x = self.concat2subnets([fe_latent, re_latent]) # apply hidden layers for key, layer in self.dense_hidden_layers.items(): x = layer(x) # Optional, add re layer if self.add_re_2_meclass: x = self.re_layer((x, z), training=training) x = self.act(x) me_latent = self.dense_me_latent(x) return me_latent class MixedEffectsModel(tf.keras.Model): """ MixedEffectsModel. It is a mixed effects classifier which processes inputs through a Mixed Effects Encoder and a dense output layer for classification. It's designed to handle both fixed effects (FE) and random effects (RE) latent spaces, making it suitable for scenarios where both fixed and random effects are considered. Parameters: - n_latent_dims (int): The number of dimensions in the mixed effects latent space created by the Mixed Effects Encoder. - layer_units (list of int): The number of units in each dense hidden layer within the Mixed Effects Encoder. - n_pred (int): The number of units in the final dense output layer, typically corresponding to the number of classes for classification. - post_loc_init_scale (float): Initial scale for the location in the post-RE layer within the Mixed Effects Encoder, used if an RE layer is added. - prior_scale (float): Scale of the prior in the RE layer within the Mixed Effects Encoder, used if an RE layer is added. - kl_weight (float): Weight of the KL divergence in the loss within the Mixed Effects Encoder, used if an RE layer is added. - add_re_2_meclass (bool): Determines whether to add an RE layer to the Mixed Effects Classifier within the Mixed Effects Encoder. - name (str): Name of the model. - **kwargs: Additional keyword arguments for the base Model class. The model encapsulates a Mixed Effects Encoder for processing the FE and RE latent spaces, followed by a dense output layer with softmax activation for classification. The `call` method: Processes the inputs through the Mixed Effects Encoder and then through the dense output layer. Inputs: - fe_latent: The latent representation of the fixed effects. - re_latent: The latent representation of the random effects. - z: Additional features or information, used in the RE layer of the Mixed Effects Encoder if `add_re_2_meclass` is True. Returns: - y: The classification output, with probabilities for each class. """ def __init__(self, n_latent_dims: int = 2, layer_units: list = [10,5], n_pred: int = 10, post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, add_re_2_meclass = False, name='mec', **kwargs): super(MixedEffectsModel, self).__init__(**kwargs) self.n_latent_dims = n_latent_dims self.layer_units = layer_units self.n_pred = n_pred # MixedEffectsmodule self.encoder = MixedEffectsEncoder(n_latent_dims = self.n_latent_dims, layer_units = self.layer_units, post_loc_init_scale = post_loc_init_scale, prior_scale = prior_scale, kl_weight = kl_weight, add_re_2_meclass = add_re_2_meclass, name = 'me_encoder', **kwargs) self.dense_out = Dense(self.n_pred, activation='softmax', name=name + '_out') def call(self, inputs,training=None): # fe_latent, re_latent,z = inputs me_latent = self.encoder(inputs,training=training) # dense out with softmax activation y = self.dense_out(me_latent) return y