"""
Author: Aixa Andrade, in collaboration with Son Nguyen.

Code inspired by the original ARMED convolutional autoencoder written by
Kevin Nguyen for the melanoma experiment; some snippets are borrowed verbatim
from the ARMED paper (Nguyen et al., 2023). This module defines custom Dense
layers for building the scMEDAL vector autoencoders.
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
import tensorflow.keras.layers as tkl
from tensorflow.keras import layers
from collections.abc import Iterable

from scMEDAL.models.random_effects import ClusterScaleBiasBlock

class TiedDenseTranspose(tf.keras.layers.Layer):
"""
A tied dense transpose layer that shares weights with a source dense layer.
Attributes:
source_layer (tf.keras.layers.Dense): Source dense layer to tie weights with.
activation (tf.keras.activations): Activation function for the layer.
units (int): Number of units for the layer.
kernel (tf.Variable): Shared weights with the source layer.
bias_t (tf.Variable): Bias for the layer.
"""
    # Inspired by the Medium article "Building an autoencoder with tied weights in Keras"
    # by Laurence Mayrand-Provencher (2019):
    # https://medium.com/@lmayrandprovencher/building-an-autoencoder-with-tied-weights-in-keras-c4a559c529a2
def __init__(self, source_layer: tf.keras.layers.Dense, units, activation=None, **kwargs):
"""
Initialize the TiedDenseTranspose layer.
Args:
source_layer (tf.keras.layers.Dense): Source dense layer to tie weights with.
units (int): Number of units for the layer.
activation (str, optional): Activation function to use. Defaults to None.
**kwargs: Additional keyword arguments.
"""
self.source_layer = source_layer
self.activation = tf.keras.activations.get(activation)
self.units = units
super().__init__(**kwargs)
def build(self, batch_input_shape):
"""Build the layer weights."""
        # Only the kernel (weights) is shared with the source layer; the bias is not tied.
        self.kernel = self.source_layer.kernel
        # The transpose layer gets its own bias, initialized to zeros.
        self.bias_t = self.add_weight(name='bias_t',
                                      shape=(self.units,),
                                      initializer="zeros")
super().build(batch_input_shape)
def call(self, inputs):
"""Apply the layer operations on the input tensor."""
return self.activation(tf.matmul(inputs, self.kernel, transpose_b=True) + self.bias_t)
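
# Minimal usage sketch for TiedDenseTranspose (illustrative only; the layer
# sizes below are assumptions, not values from the paper). The source Dense
# layer must be built before its kernel can be tied.
#
#   enc_dense = Dense(units=5, name="enc_dense")
#   _ = enc_dense(tf.zeros((1, 16)))                 # build -> kernel shape (16, 5)
#   dec_dense = TiedDenseTranspose(source_layer=enc_dense, units=16, activation="selu")
#   y = dec_dense(tf.zeros((1, 5)))                  # -> (1, 16), reuses enc_dense.kernel
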
class Encoder(tf.keras.Model):
"""
Encoder Layer for Neural Networks with optional batch normalization.
    Attributes:
        n_latent_dim (int): Number of latent dimensions.
        layer_units (list): List of units for each dense layer.
        return_layer_activations (bool): Flag to determine if layer activations should be returned.
        return_encoder_layers (bool): Flag to determine if encoder layers should be returned.
        use_batch_norm (bool): Whether batch normalization follows each dense layer.
        all_layers (dict): Dictionary mapping each block name to its ordered list of layers.
        dense_blocks (dict): Dictionary containing only the dense layers.
"""
def __init__(self,
n_latent_dims: int=2,
layer_units: list=[9, 7, 5],
return_layer_activations: bool=False,
return_encoder_layers: bool=False,
use_batch_norm: bool=False, # Flag to determine if batch normalization should be used
name='encoder',
**kwargs):
super(Encoder, self).__init__(name=name, **kwargs)
self.n_latent_dim = n_latent_dims
self.layer_units = layer_units
self.return_layer_activations = return_layer_activations
self.return_encoder_layers = return_encoder_layers
self.use_batch_norm = use_batch_norm
# Create dictionaries for the blocks and layers
self.dense_blocks = {}
self.all_layers = {}
# Fill the dictionaries using a for loop
for i, n_units in enumerate(self.layer_units):
key_name = "dense_" + str(i)
            dense_layer = Dense(units=n_units, activation=None, name=key_name)  # activation is applied as a separate layer below
self.dense_blocks[key_name] = dense_layer
self.all_layers[key_name] = [dense_layer]
if self.use_batch_norm:
bn_key_name = "batch_norm_" + str(i)
bn_layer = BatchNormalization(name=bn_key_name)
self.all_layers[key_name].append(bn_layer)
            # Activation is applied after the (optional) batch normalization
            activation_layer = tf.keras.layers.Activation('selu')
            self.all_layers[key_name].append(activation_layer)
# Define the latent layer
self.dense_latent = Dense(units=self.n_latent_dim, activation="selu", name="dense_latent")
self.all_layers["dense_latent"] = [self.dense_latent]
def call(self, inputs, training=None):
x = inputs
layer_activations = []
        # Iterate through the layer blocks in insertion order
        for key, block_layers in self.all_layers.items():
            for layer in block_layers:
                # pass training so BatchNormalization behaves correctly at train/inference time
                x = layer(x, training=training)
            layer_activations.append(x)
if self.return_layer_activations:
return layer_activations
elif self.return_encoder_layers:
return self.all_layers, x
else:
return x
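
# Usage sketch for Encoder (shapes are illustrative assumptions):
#
#   enc = Encoder(n_latent_dims=2, layer_units=[9, 7, 5], use_batch_norm=True)
#   latent = enc(tf.random.uniform((32, 16)), training=True)   # -> (32, 2)
#   # With return_layer_activations=True the call returns the list of block
#   # activations; its last element is the latent representation.
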
class Decoder(tf.keras.Model):
"""
Decoder Layer for Neural Networks.
The Decoder layers can be Tied with the Encoder layers if encoder layers are provided and if the tied_weights = True.
Attributes:
encoder_dense_layers (list): List of encoder dense layers to tie weights with.
in_shape (tuple): Input shape for the autoencoder (encoder input shape). Input shape encoder = output shape decoder
layer_units (list): List of units for each dense layer.
last_activation (str): Last activation function for the decoder.
        all_layers (dict): Dictionary containing all the decoder layers.
"""
def __init__(self,
in_shape: tuple,
encoder_layers: list = [],
layer_units: list=[9,7,5],
last_activation: str='sigmoid',
name='decoder',
tied_weights = True,
**kwargs):
"""
Initialize the Decoder.
Args:
encoder_layers (list, optional): List of encoder layers to tie weights with. If you want a TiedDecoder, you have to provide the encoder_layers. Defaults to empty list.
in_shape (tuple): Input shape for the autoencoder (encoder input shape). Input shape encoder = output shape decoder
            layer_units (list, optional): List containing the number of units for each dense layer. Defaults to [9, 7, 5].
last_activation (str, optional): Last activation function for the decoder. Defaults to "sigmoid".
name (str, optional): Name of the layer. Defaults to 'decoder'.
tied_weights (bool, optional): If True, the layers of the Decoder are Tied with the Encoder. Else: The layers of the Decoder are Dense. Defaults to True.
**kwargs: Additional keyword arguments.
"""
super(Decoder, self).__init__(name=name, **kwargs)
self.in_shape = in_shape
self.layer_units = layer_units
self.last_activation = last_activation
self.all_layers = {}
self.tied_weights = tied_weights
        if self.tied_weights and len(encoder_layers) > 0:
            # tied_weights = True --> decoder layers share their kernels with the encoder layers

            def is_iterable(obj):
                """Check if the object is iterable but not a string/bytes."""
                return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))

            # Flatten encoder_layers, which may be a flat list of layers or a list of
            # lists (Encoder.all_layers stores one list per block), keeping dense layers only.
            encoder_dense_layers = [layer for item in encoder_layers
                                    for layer in (item if is_iterable(item) else [item])
                                    if "dense" in layer.name]
self.encoder_dense_layers = encoder_dense_layers
            # Build the decoder by iterating over the encoder dense layers in reverse,
            # skipping the first encoder layer (it is tied to the output layer below)
for n_units, e_layer in zip(self.layer_units[::-1], self.encoder_dense_layers[1:][::-1]):
key_name = e_layer.name + "_t"
self.all_layers[key_name] = TiedDenseTranspose(source_layer=e_layer, units=n_units, activation="selu", name=key_name)
            # Output layer: shares its kernel with the first encoder layer.
            # The default last activation is sigmoid so reconstructed values lie in [0, 1].
key_name = "dense_out"
self.all_layers[key_name] = TiedDenseTranspose(source_layer=self.encoder_dense_layers[0], units=self.in_shape[-1], activation=self.last_activation, name=key_name)
else:
            # tied_weights = False --> decoder layers are independent Dense layers
            # Build the decoder by looping through layer_units in reverse
for i,n_units in enumerate(self.layer_units[::-1]):
key_name = "dense_"+str(len(self.layer_units)-i)
self.all_layers[key_name] = Dense(units=n_units, activation="selu", name=key_name)
            # The default last activation is sigmoid so reconstructed values lie in [0, 1]
key_name = "dense_out"
self.all_layers[key_name] = Dense(units=self.in_shape[-1], activation=self.last_activation, name=key_name)
def call(self, inputs, training=None):
"""
Call the decoder layer with input data.
Args:
inputs (tf.Tensor): Input tensor data.
training (bool, optional): If in training mode or not. Defaults to None.
Returns:
tf.Tensor: Processed output tensor.
"""
x = inputs
        # Apply the decoder layers (TiedDenseTranspose or Dense) in order
        for key, layer in self.all_layers.items():
            x = layer(x, training=training)
return x
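
# Usage sketch for a tied Decoder (illustrative; shapes are assumptions). The
# encoder must be called once so its kernels exist before they can be tied:
#
#   enc = Encoder(n_latent_dims=2, layer_units=[9, 7, 5])
#   _ = enc(tf.zeros((1, 16)))                       # build the encoder weights
#   dec = Decoder(in_shape=(16,), encoder_layers=list(enc.all_layers.values()),
#                 layer_units=[9, 7, 5], tied_weights=True)
#   recon = dec(tf.zeros((1, 2)))                    # latent -> (1, 16)
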
class AE(tf.keras.Model):
"""
Autoencoder (AE) Model with tied weights.
Attributes:
in_shape (tuple): Input shape for the AE.
layer_units (list): List of units for each dense layer in the encoder.
n_latent_dims (int): Number of latent dimensions for the encoder.
last_activation (str): Last activation function for the decoder.
return_layer_activations (bool): Whether to return layer activations from the encoder.
encoder (Encoder): Encoder part of the AE.
decoder (Decoder): Decoder part of the AE, it has Tied weights with the Encoder.
"""
def __init__(self,
in_shape: tuple,
n_latent_dims: int = 2,
layer_units: list = [9,7,5],
last_activation: str = "sigmoid",
return_layer_activations: bool = False,
use_batch_norm: bool=False,
name='ae',
**kwargs):
"""
Initialize the AE model.
Args:
in_shape (tuple): Input shape for the AE.
            n_latent_dims (int, optional): Number of latent dimensions for the encoder. Defaults to 2.
            layer_units (list, optional): List containing the number of units for each dense layer in the encoder. Defaults to [9, 7, 5].
last_activation (str, optional): Last activation function for the decoder. Defaults to "sigmoid".
return_layer_activations (bool, optional): Whether to return layer activations from the encoder. Defaults to False.
name (str, optional): Name of the model. Defaults to 'ae'.
**kwargs: Additional keyword arguments.
"""
super(AE, self).__init__(name=name, **kwargs)
self.in_shape = in_shape
self.layer_units = layer_units
self.n_latent_dims = n_latent_dims
self.last_activation = last_activation
self.return_layer_activations = return_layer_activations
self.use_batch_norm = use_batch_norm
self.encoder = Encoder(n_latent_dims=n_latent_dims,
layer_units=layer_units,
return_layer_activations=self.return_layer_activations,
use_batch_norm=self.use_batch_norm)
        # Encoder stores its layers in a dict of blocks; flatten the values into a list for the decoder
encoder_layers_list = list(self.encoder.all_layers.values())
self.decoder = Decoder(in_shape=self.in_shape, encoder_layers=encoder_layers_list,
layer_units=self.layer_units,
last_activation=self.last_activation)
def call(self, inputs, training=None):
"""
Call the AE model with input data.
Args:
inputs (tf.Tensor): Input tensor data.
training (bool, optional): If in training mode or not. Defaults to None.
Returns:
tf.Tensor: Processed output tensor.
"""
# Get the encoder output. If return_layer_activations is True,
# the encoder returns all layer activations, else just the latent representation.
encoder_output = self.encoder(inputs, training=training)
# Determine the latent representation based on return_layer_activations flag
latent = encoder_output[-1] if self.return_layer_activations else encoder_output
out = self.decoder(latent, training=training)
return out
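
# Usage sketch for the tied autoencoder (hedged: data shapes and
# hyperparameters are assumptions for illustration):
#
#   ae = AE(in_shape=(16,), n_latent_dims=2, layer_units=[9, 7, 5])
#   ae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
#              loss=tf.keras.losses.MeanSquaredError())
#   x = tf.random.uniform((128, 16))
#   ae.fit(x, x, epochs=2, batch_size=32)            # reconstruct the input
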
class AEC(tf.keras.Model):
"""
An autoencoder-based classifier model built using TensorFlow's Keras API.
This model is a combination of an autoencoder for unsupervised learning and a classifier for supervised learning. It is designed to work with input data in the specified input shape, compress it into a latent space using an encoder, and then reconstruct the input from this compressed representation using a decoder. Additionally, it uses the latent representation for classification purposes.
The autoencoder part of the model is a 'tied' autoencoder, meaning that the weights of the encoder are tied to the weights of the decoder. This type of architecture can be beneficial for certain types of data compression and reconstruction tasks.
Parameters:
in_shape (tuple): The shape of the input data.
n_latent_dims (int, optional): The number of dimensions for the latent space representation. Default is 2.
layer_units (list, optional): The number of units in each layer of the encoder (and by extension, the decoder). Default is [9, 7, 5].
last_activation (str, optional): The activation function to be used in the last layer of the autoencoder. Default is 'sigmoid'.
return_layer_activations (bool, optional): Flag to determine whether the encoder should return all layer activations or just the final latent representation. Default is False.
n_pred (int, optional): The number of prediction classes for the classifier. Default is 20.
layer_units_latent_classifier (list, optional): The number of units in each layer of the classifier. Default is [2].
name (str, optional): Name of the model. Default is 'ae_class'.
The model has three main components:
- An encoder that reduces the input to a lower-dimensional latent space.
- A decoder that reconstructs the input from the latent space.
- A classifier that uses the latent space representation for classification tasks.
The `call` method of the model takes in input data and optionally a training flag and returns a dictionary with two keys: 'reconstruction_output' for the output of the autoencoder, and 'classification_output' for the output of the classifier.
    Example:
        model = AEC(in_shape=(784,))
        # For training or inference
        output = model(data)  # dict with 'reconstruction_output' and 'classification_output'
"""
def __init__(self,
in_shape: tuple,
n_latent_dims: int = 2,
layer_units: list = [9,7,5],
last_activation: str = "sigmoid",
return_layer_activations = False,
n_pred=20,
layer_units_latent_classifier=[2],
use_batch_norm: bool=False,
name='aec',
**kwargs):
super(AEC, self).__init__(name=name, **kwargs)
self.in_shape = in_shape
self.layer_units = layer_units
self.n_latent_dims = n_latent_dims
self.last_activation = last_activation
self.return_layer_activations = return_layer_activations
self.n_pred = n_pred
self.layer_units_latent_classifier = layer_units_latent_classifier
self.use_batch_norm = use_batch_norm
self.encoder = Encoder(n_latent_dims=n_latent_dims,
layer_units=layer_units,
return_layer_activations=self.return_layer_activations,
use_batch_norm=self.use_batch_norm)
        # Encoder stores its layers in a dict of blocks; flatten the values into a list for the decoder
        encoder_layers_list = list(self.encoder.all_layers.values())
        # Tied autoencoder: decoder kernels are shared with the encoder
        self.decoder = Decoder(in_shape=self.in_shape, encoder_layers=encoder_layers_list,
                               layer_units=self.layer_units,
                               last_activation=self.last_activation)
        self.classifier = Classifier(n_clusters=self.n_pred, layer_units=self.layer_units_latent_classifier)
    def call(self, inputs, training=None):
        # Get the encoder output. If return_layer_activations is True,
        # the encoder returns all layer activations, else just the latent representation.
        encoder_output = self.encoder(inputs, training=training)
        # Determine the latent representation based on the return_layer_activations flag
        latent = encoder_output[-1] if self.return_layer_activations else encoder_output
        # Pass the latent representation through the decoder and classifier
        recon = self.decoder(latent, training=training)
        classification = self.classifier(latent)
        return {'reconstruction_output': recon, 'classification_output': classification}
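
# Usage sketch for AEC (hedged: shapes, class counts, and loss weights are
# assumptions). Keras matches the loss dict keys to the output dict keys:
#
#   aec = AEC(in_shape=(16,), n_latent_dims=2, n_pred=20)
#   aec.compile(optimizer='adam',
#               loss={'reconstruction_output': 'mse',
#                     'classification_output': 'categorical_crossentropy'},
#               loss_weights={'reconstruction_output': 1.0, 'classification_output': 0.01})
#   # aec.fit(x, {'reconstruction_output': x, 'classification_output': y_onehot})
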
class AdversarialClassifier(tkl.Layer):
def __init__(self,
n_clusters: int,
n_latent_dims: int=2,
layer_units: list=[5, 4],
name: str='adversary',
**kwargs):
"""Adversarial classifier.
Args:
n_clusters (int): number of clusters (classes)
            layer_units (list, optional): Neurons in each layer. Can be a list of any
                length. Defaults to [5, 4].
name (str, optional): Model name. Defaults to 'adversary'.
"""
super(AdversarialClassifier, self).__init__(name=name, **kwargs)
self.n_clusters = n_clusters
self.layer_units = layer_units
self.all_layers = []
for iLayer, neurons in enumerate(layer_units):
self.all_layers += [tkl.Dense(neurons,
activation='relu',
name=name + '_dense' + str(iLayer))]
        self.all_layers += [tkl.Dense(self.n_clusters, activation='softmax', name=name + '_dense_out')]
def call(self, inputs):
        # Accept a list of tensors (e.g. all encoder activations) by concatenating them
        if isinstance(inputs, list):
            inputs = tf.concat(inputs, axis=-1)
x = inputs
for layer in self.all_layers:
x = layer(x)
return x
def get_config(self):
return {'n_clusters': self.n_clusters,
'layer_units': self.layer_units}
class Classifier(tf.keras.layers.Layer):
"""
A custom Keras Layer for classification tasks.
This layer implements a classifier with a user-defined number of dense layers
followed by an output layer for clustering (can be used to predict donors/batches). Optionally, it can also implement
another classifier subnet for a second prediction (can be used to predict celltypes).
Attributes:
-----------
layer_units : list
List of integers specifying the number of units in each dense layer.
layers_cluster : dict
Dictionary containing dense layers for clustering.
n_clusters : int
Number of clusters for the classification task.
n_pred : int
Number of predictions.
get_pred : bool
Flag to determine if prediction subnet should be built.
layers_pred : dict
Dictionary containing dense layers for prediction, built only if get_pred is True.
Methods:
--------
call(inputs, training=None):
Perform the forward pass for the clustering and optionally for prediction.
get_config():
Returns a dictionary containing the configuration of the classifier (i.e., n_clusters).
Parameters:
-----------
n_clusters : int
Number of clusters for the classification task.
layer_units : list, optional
List of integers specifying the number of units in each dense layer. Defaults to [32, 16].
n_pred : int, optional
        Number of predictions, only used if get_pred is True. Defaults to 10.
get_pred : bool, optional
Flag to determine if a subnet for predictions should be built. Defaults to False.
name : str, optional
Name of the layer. Defaults to 'latent_classifier'.
**kwargs :
Additional keyword arguments inherited from tf.keras.layers.Layer.
"""
def __init__(self,
n_clusters: int,
layer_units: list=[2],
n_pred: int = 10,
get_pred = False,
name='latent_classifier',
**kwargs):
super(Classifier, self).__init__(name=name, **kwargs)
self.layer_units = layer_units
self.layers_cluster = {}
self.n_clusters = n_clusters
self.n_pred = n_pred
self.get_pred = get_pred
if self.get_pred:
self.layers_pred = {}
# Fill the dictionaries using a for loop
for i, n_units in enumerate(self.layer_units):
key_name = "dense_" + str(i)
self.layers_cluster[key_name] = Dense(units=n_units, activation="relu", name=key_name)
            # If get_pred is True, build a parallel dense subnet for class predictions
            if self.get_pred:
                self.layers_pred[key_name] = Dense(units=n_units, activation="relu", name=key_name + "_pred")
        # This layer outputs the cluster membership probabilities
        self.layers_cluster["dense_out"] = Dense(self.n_clusters, activation='softmax', name=name + '_cluster_out')
        if self.get_pred:
            # Softmax output over the prediction classes; named distinctly from the cluster head
            self.layers_pred["dense_out"] = Dense(self.n_pred, activation='softmax', name=name + '_pred_out')
def call(self, inputs, training=None):
c = inputs
for key, layer in self.layers_cluster.items():
c = layer(c)
        # final c: (n_samples, n_clusters) matrix with the probability of each sample belonging to each cluster
if self.get_pred:
y = inputs
for key, layer in self.layers_pred.items():
y = layer(y)
            # return (class predictions, cluster predictions)
            return y, c
else:
return c
def get_config(self):
return {'n_clusters': self.n_clusters}
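
# Usage sketch for Classifier (hedged: dimensions are assumptions):
#
#   clf = Classifier(n_clusters=10, layer_units=[2])
#   c = clf(tf.zeros((32, 2)))                       # -> (32, 10) cluster probabilities
#   clf2 = Classifier(n_clusters=10, n_pred=4, get_pred=True)
#   y, c = clf2(tf.zeros((32, 2)))                   # class preds (32, 4), cluster preds (32, 10)
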
class DomainAdversarialAE(AE):
"""
scMEDAL Fixed Effects subnetwork (FE)
An extension of the autoencoder (AE) that integrates an adversarial
classifier in its architecture to perform unsupervised domain adaptation.
This class enables training of the AE such that the latent representation
is invariant to domain shifts, making the model robust against changes in
the data distribution.
Attributes:
in_shape (tuple): Shape of the input data.
n_clusters (int): Number of clusters for adversarial classification.
n_latent_dims (int): Dimensionality of the latent space.
layer_units (list): Number of units in each dense layer of the encoder/decoder.
last_activation (str): Activation function for the decoder's output layer.
get_pred (bool): Whether the model includes latent space predictions.
use_batch_norm (bool): Whether batch normalization is applied to the dense layers.
n_pred (int): Number of prediction classes for the latent classifier (if `get_pred` is True).
layer_units_latent_classifier (list): Number of units in each layer of the latent classifier (if `get_pred` is True).
latent_classifier (Classifier): Classifier for generating predictions from the latent space (if `get_pred` is True).
encoder (Encoder): Encoder component of the autoencoder.
decoder (Decoder): Decoder component of the autoencoder.
adversary (AdversarialClassifier): Adversarial classifier for unsupervised domain adaptation.
"""
def __init__(self, in_shape: tuple, n_clusters: int,
n_latent_dims: int=2,
layer_units: list=[9,7,5],
last_activation: str="sigmoid",
n_pred: int=10,
layer_units_latent_classifier: list=[2],
get_pred=False,
use_batch_norm: bool=False,
name='da_ae',
**kwargs):
"""
Initialize the Domain Adversarial Autoencoder.
Args:
in_shape (tuple): Shape of the input data.
n_clusters (int): Number of clusters for adversarial classification.
n_latent_dims (int, optional): Dimensionality of the latent space. Default is 2.
layer_units (list, optional): Number of units in each dense layer of the encoder/decoder. Default is [9, 7, 5].
last_activation (str, optional): Activation function for the decoder's output layer. Default is "sigmoid".
n_pred (int, optional): Number of prediction classes for the latent classifier. Default is 10.
layer_units_latent_classifier (list, optional): Number of units in each layer of the latent classifier. Default is [2].
get_pred (bool, optional): Whether to include latent space predictions. Default is False.
use_batch_norm (bool, optional): Whether to apply batch normalization. Default is False.
name (str, optional): Name of the model instance. Default is 'da_ae'.
**kwargs: Additional arguments for the base class.
"""
        # Skip AE.__init__ (which would build its own encoder/decoder) and
        # initialize tf.keras.Model directly
        super(AE, self).__init__(name=name, **kwargs)
self.in_shape = in_shape
self.n_clusters = n_clusters
self.n_latent_dims = n_latent_dims
self.layer_units = layer_units
self.last_activation = last_activation
self.get_pred = get_pred
self.use_batch_norm = use_batch_norm
if self.get_pred:
self.n_pred = n_pred
self.layer_units_latent_classifier = layer_units_latent_classifier
            # The latent classifier returns class predictions
            self.latent_classifier = Classifier(n_clusters=self.n_pred, layer_units=self.layer_units_latent_classifier)
        # Autoencoder: encoder + decoder
self.encoder = Encoder(n_latent_dims = n_latent_dims,
layer_units=self.layer_units,
return_layer_activations=True,
use_batch_norm=self.use_batch_norm)
encoder_layers_list = list(self.encoder.all_layers.values())
self.decoder = Decoder(in_shape=self.in_shape,encoder_layers = encoder_layers_list,layer_units = self.layer_units, last_activation = self.last_activation)
#adversarial classifier
self.adversary = AdversarialClassifier(n_clusters = self.n_clusters,
n_latent_dims = self.n_latent_dims,
layer_units=self.layer_units)
    def call(self, inputs, training=None):
        """
        Forward pass through the Domain Adversarial Autoencoder.
        Args:
            inputs (tuple): Tuple containing the input data and the cluster membership matrix.
            training (bool, optional): If in training mode or not. Defaults to None.
        Returns:
            tuple: Reconstruction, (optionally class predictions,) and the adversary's cluster predictions.
        """
x, clusters = inputs
        # Encoder returns all of its layer activations (return_layer_activations=True)
        encoder_activations = self.encoder(x, training=training)
        # The adversary sees every encoder activation, not just the latent layer
pred_cluster = self.adversary(encoder_activations)
# latent space is the last activation layer
latent = encoder_activations[-1]
# decoder is applied to latent
recon = self.decoder(latent,training=training)
if self.get_pred:
# classification
pred_class = self.latent_classifier(latent)
return (recon, pred_class, pred_cluster)
else:
return (recon, pred_cluster)
def compile(self,
loss_recon=tf.keras.losses.MeanSquaredError(),
loss_multiclass=tf.keras.losses.CategoricalCrossentropy(),
metric_multiclass=tf.keras.metrics.CategoricalAccuracy(name='acc'),
                opt_autoencoder=tf.keras.optimizers.Adam(learning_rate=0.0001),
                opt_adversary=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss_recon_weight=1.0,
loss_gen_weight=0.05,
loss_class_weight=0.01):
"""
Compile the model with specified losses, metrics, and optimizers.
Args:
loss_recon (tf loss): Reconstruction loss function.
loss_multiclass (tf loss): multiclass loss function. It works for all multiclass tasks.
metric_multiclass (tf metric): Metric for adversarial classifier performance.
opt_autoencoder (tf optimizer): Optimizer for autoencoder.
opt_adversary (tf optimizer): Optimizer for adversarial classifier.
loss_recon_weight (float): Weight for the reconstruction loss.
loss_gen_weight (float): Weight for the adversarial loss (generator part).
loss_class_weight (float): Weight for the class loss. Only used if get_pred ==True.
"""
super().compile()
self.loss_recon = loss_recon
        # The adversarial and latent-classification tasks use the same loss function under different names
        self.loss_adv = loss_multiclass
        self.loss_class = loss_multiclass
self.opt_autoencoder = opt_autoencoder
self.opt_adversary = opt_adversary
# track mean loss
self.loss_recon_tracker = tf.keras.metrics.Mean(name='recon_loss')
self.loss_adv_tracker = tf.keras.metrics.Mean(name='adv_loss')
self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss')
        # define metrics; keep separate instances for the adversary and the latent
        # classifier so the two tasks do not share accumulator state or a metric name
        self.metric_adv = metric_multiclass
        self.metric_class = type(metric_multiclass).from_config(
            {**metric_multiclass.get_config(), 'name': 'class_acc'})
# define loss weights
self.loss_recon_weight = loss_recon_weight
self.loss_gen_weight = loss_gen_weight
        if self.get_pred:  # define latent classifier loss weight and tracker
            self.loss_class_weight = loss_class_weight
            self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss')
@property
def metrics(self):
if self.get_pred:
return [self.loss_recon_tracker,
self.loss_class_tracker,
self.loss_adv_tracker,
self.loss_total_tracker,
self.metric_adv,
self.metric_class]
else:
return [self.loss_recon_tracker,
self.loss_adv_tracker,
self.loss_total_tracker,
self.metric_adv]
def train_step(self, data):
"""
Perform a training step for the model.
Args:
data (tuple): Tuple containing input data, target data, and optionally sample weights.
Returns:
dict: Dictionary containing values of tracked metrics.
"""
        # load data
        x, clusters = data[0]
        if self.get_pred:
            _, labels = data[1]
        sample_weights = None if len(data) != 3 else data[2]
        # check that the batch shapes are consistent
assert x.shape[0] == clusters.shape[0], "Mismatch between x and clusters"
if self.get_pred:
assert x.shape[0] == labels.shape[0], "Mismatch between x and labels"
        # --- Train the adversary ---
        encoder_outs = self.encoder(x, training=True)
        # compute the adversarial loss
        with tf.GradientTape() as gt:
            pred_cluster = self.adversary(encoder_outs)
            loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)
        # apply gradients to minimize the adversarial loss
        grads_adv = gt.gradient(loss_adv, self.adversary.trainable_variables)
        self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_variables))
# Update adversarial loss tracker
self.metric_adv.update_state(clusters, pred_cluster)
self.loss_adv_tracker.update_state(loss_adv)
        # --- Train the autoencoder (and latent classifier, if present) ---
        with tf.GradientTape() as gt2:
            # forward pass through the full model
            outputs = self(inputs=(x, clusters), training=True)
            if self.get_pred:
                pred_recon, pred_class, pred_cluster = outputs
            else:
                pred_recon, pred_cluster = outputs
#compute individual losses
loss_recon = self.loss_recon(x, pred_recon, sample_weight=sample_weights)
loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights)
if self.get_pred:
loss_class = self.loss_class(labels, pred_class, sample_weight=sample_weights)
                # total loss: weighted reconstruction + class loss - adversarial (generator) loss
                total_loss = (self.loss_recon_weight * loss_recon) \
                             + (self.loss_class_weight * loss_class) \
                             - (self.loss_gen_weight * loss_adv)
            else:
                # total loss: weighted reconstruction - adversarial (generator) loss
                total_loss = (self.loss_recon_weight * loss_recon) - (self.loss_gen_weight * loss_adv)
if self.get_pred: # +latent classifier trainable vars
lsWeights = self.encoder.trainable_variables + self.decoder.trainable_variables \
+ self.latent_classifier.trainable_variables
else:
lsWeights = self.encoder.trainable_variables + self.decoder.trainable_variables
#backpropagate
grads_aec = gt2.gradient(total_loss, lsWeights)
self.opt_autoencoder.apply_gradients(zip(grads_aec, lsWeights))
# Update loss trackers
if self.get_pred:
self.metric_class.update_state(labels, pred_class)
self.loss_class_tracker.update_state(loss_class)
self.loss_recon_tracker.update_state(loss_recon)
self.loss_total_tracker.update_state(total_loss)
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
"""
Perform a testing (validation) step for the model.
Args:
data (tuple): Tuple containing input data and target data.
Returns:
dict: Dictionary containing values of tracked metrics.
"""
        x, clusters = data[0]
        if self.get_pred:
            _, labels = data[1]
        # check that the batch shapes are consistent
        assert x.shape[0] == clusters.shape[0], "Mismatch between x and clusters"
        if self.get_pred:
            assert x.shape[0] == labels.shape[0], "Mismatch between x and labels"
# apply model
outputs = self(inputs=(x, clusters), training=False)
        if self.get_pred:
            pred_recon, pred_class, pred_cluster = outputs
        else:
            pred_recon, pred_cluster = outputs
        # compute individual losses
        loss_recon = self.loss_recon(x, pred_recon)
        loss_adv = self.loss_adv(clusters, pred_cluster)
        # compute the total loss
        if self.get_pred:
            loss_class = self.loss_class(labels, pred_class)
            # total loss: weighted reconstruction + class loss - adversarial (generator) loss
            total_loss = (self.loss_recon_weight * loss_recon) \
                         + (self.loss_class_weight * loss_class) \
                         - (self.loss_gen_weight * loss_adv)
        else:
            total_loss = (self.loss_recon_weight * loss_recon) - (self.loss_gen_weight * loss_adv)
#update metrics and losses
self.metric_adv.update_state(clusters, pred_cluster)
self.loss_recon_tracker.update_state(loss_recon)
self.loss_adv_tracker.update_state(loss_adv)
self.loss_total_tracker.update_state(total_loss)
if self.get_pred:
self.metric_class.update_state(labels, pred_class)
self.loss_class_tracker.update_state(loss_class)
return {m.name: m.result() for m in self.metrics}
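
# Training sketch for the FE subnetwork (hedged: the data format mirrors
# train_step's unpacking, but sizes and loss weights are illustrative assumptions):
#
#   model = DomainAdversarialAE(in_shape=(16,), n_clusters=4)
#   model.compile(loss_recon_weight=1.0, loss_gen_weight=0.05)
#   x = tf.random.uniform((128, 16))
#   clusters = tf.one_hot(tf.random.uniform((128,), maxval=4, dtype=tf.int32), depth=4)
#   model.fit((x, clusters), epochs=2, batch_size=32)
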
class RandomEffectEncoder(Encoder):
"""
RandomEffectEncoder: A specialized encoder that incorporates random effects with dense layers.
Inherits from the provided Encoder class. This encoder is designed to model random effects by
introducing specialized layers for handling them. Each dense layer is followed by a random effect layer
and an activation layer.
Attributes:
n_latent_dims (int): Number of latent dimensions.
layer_units (list): List containing the number of units for each dense layer.
post_loc_init_scale (float): Initial scale for the location of the posterior distribution.
prior_scale (float): Scale for the prior distribution.
        kl_weight (float): Weighting factor for the Kullback-Leibler divergence.
re_layers (dict): Dictionary containing random effect layers.
act_layers (dict): Dictionary containing activation layers.
layer_blocks (dict): Dictionary containing blocks of (dense, random effect, activation) layers.
Args:
n_latent_dims (int, optional): Number of latent dimensions. Defaults to 2.
layer_units (list, optional): List containing the number of units for each dense layer. Defaults to [8].
post_loc_init_scale (float, optional): Initial scale for the location of the posterior distribution. Defaults to 0.1.
prior_scale (float, optional): Scale for the prior distribution. Defaults to 0.25.
        kl_weight (float, optional): Weighting factor for the Kullback-Leibler divergence. Defaults to 1e-5.
name (str, optional): Name of the encoder. Defaults to 'encoder'.
**kwargs: Additional keyword arguments.
"""
def __init__(self,
n_latent_dims: int=2,
layer_units: list=[8],
post_loc_init_scale: float=0.1,
prior_scale: float=0.25,
kl_weight: float=1e-5,
name = 'encoder',
**kwargs):
""" Initialize the RandomEffectEncoder. """
super(RandomEffectEncoder, self).__init__(n_latent_dims=n_latent_dims,
layer_units=layer_units,name=name, **kwargs)
#dictionary of random effect layers
self.re_layers = {}
#dictionary of activation layers
self.act_layers = {}
#Build blocks of (dense, RE, activation layers)
self.layer_blocks = {}
#dense blocks are inherited from Encoder class
        for key, layer in self.dense_blocks.items():
            # layer index, parsed from the key ("dense_<i>")
            layer_i = key.split("_")[-1]
#random effect layer
self.re_layers["re_"+layer_i] = ClusterScaleBiasBlock(layer.units,
post_loc_init_scale = post_loc_init_scale,
prior_scale = prior_scale,
kl_weight = kl_weight,
name = name + '_re_'+layer_i)
#act layer
self.act_layers["act_"+layer_i] = Activation('selu')
#add blocks of (dense, RE, activation layers)
self.layer_blocks["block_"+layer_i] = (layer,self.re_layers["re_"+layer_i], self.act_layers["act_"+layer_i])
#define re_encoder_layers
self.re_encoder_layers = {**self.layer_blocks, "dense_latent": self.dense_latent}
def call(self, inputs, training=None):
"""
Forward pass for the RandomEffectEncoder.
Args:
inputs (tuple): A tuple containing two elements - the input data (x) and the random effects data (z).
training (bool, optional): If in training mode or not. Defaults to None.
Returns:
tf.Tensor: Transformed input after passing through dense, random effect and activation layers.
"""
        x, z = inputs
        for key, (dense, re, activation) in self.layer_blocks.items():
            x = dense(x)
            x = re((x, z), training=training)
            x = activation(x)
        x = self.dense_latent(x)
        return x
class RandomEffectDecoder(Decoder):
def __init__(self,
in_shape: tuple,
layer_units: list=[8],
last_activation: str='sigmoid',
post_loc_init_scale: float=0.1,
prior_scale: float=0.25,
kl_weight: float=1e-5,
name = 'decoder',
**kwargs):
""" Initialize the RandomEffectDecoder. """
        # The RandomEffectDecoder does not use tied weights
super(RandomEffectDecoder, self).__init__(in_shape = in_shape,
layer_units = layer_units,
last_activation = last_activation,
name = name,
tied_weights = False,
**kwargs)
#dictionary of random effect layers
self.re_layers = {}
#dictionary of activation layers
self.act_layers = {}
#Build blocks of (dense, RE, activation layers)
self.layer_blocks = {}
        # dense layers are inherited from the Decoder class
for key, layer in self.all_layers.items():
#layer i
layer_i = key.split("_")[-1]
#random effect layer
self.re_layers["re_"+layer_i] = ClusterScaleBiasBlock(layer.units,
post_loc_init_scale = post_loc_init_scale,
prior_scale = prior_scale,
kl_weight = kl_weight,
name = name + '_re_'+layer_i)
            if key == 'dense_out':  # the output block uses last_activation instead of selu
                self.act_layers["last_act"] = Activation(self.last_activation, name=name + '_act_' + self.last_activation)
                # add block of (dense, RE, activation) layers
                self.layer_blocks["block_" + layer_i] = (layer, self.re_layers["re_" + layer_i], self.act_layers["last_act"])
            else:  # all other activation layers are 'selu'
                self.act_layers["act_" + layer_i] = Activation('selu', name=name + '_act_' + layer_i)
                # add block of (dense, RE, activation) layers
                self.layer_blocks["block_" + layer_i] = (layer, self.re_layers["re_" + layer_i], self.act_layers["act_" + layer_i])
self.re_decoder_layers = self.layer_blocks
def call(self, inputs, training=None):
"""
Forward pass for the RandomEffectDecoder.
Args:
inputs (tuple): A tuple containing two elements - the input data (x) and the random effects data (z).
training (bool, optional): If in training mode or not. Defaults to None.
Returns:
tf.Tensor: Transformed input after passing through dense, random effect and activation layers.
"""
x, z = inputs
for key, (dense, re, activation) in self.layer_blocks.items():
x = dense(x)
x = re((x, z), training=training)
x = activation(x)
return x
class DomainEnhancingAutoencoderClassifier(tf.keras.Model):
"""
scMEDAL Random Effects subnetwork (RE)
Autoencoder model for classification and clustering of the batch effects.
This model leverages an autoencoder structure with a domain-enhanced approach to perform
classification and clustering tasks. It comprises an encoder (`RandomEffectEncoder`), a decoder
(`RandomEffectDecoder`), and a batch classifier which operates in the latent space.
The model can predict clusters or class labels based on the latent and reconstructed representations.
Parameters:
------------
- in_shape (tuple): Input shape of the data.
- n_clusters (int, optional): Number of clusters for classification. Default is 10.
- n_latent_dims (int, optional): Dimensionality of the latent space. Default is 2.
- layer_units (list, optional): Units for each layer in the autoencoder. Default is [10, 5].
- layer_units_classifier (list, optional): Units for each layer in the classifier. Default is [2].
- n_pred (int, optional): Number of prediction classes if `get_pred` is True. Default is 10.
- last_activation (str, optional): Activation for the last layer of the autoencoder. Default is "sigmoid".
- post_loc_init_scale (float, optional): Initial scale for the posterior's location. Default is 0.1.
- prior_scale (float, optional): Scale for the prior distribution. Default is 0.25.
- kl_weight (float, optional): Weight for KL divergence loss. Default is 1e-5.
- get_pred (bool, optional): Predict class labels alongside clusters. Default is False.
- get_recon_cluster (bool, optional): Retrieve cluster prediction from reconstruction. Default is False.
- name (str, optional): Model's name. Default is "ae".
Attributes:
------------
Various components of the model such as the encoder, decoder, and classifiers are stored as attributes.
Methods:
------------
- call(inputs, training=None): Performs a forward pass of the model.
- compile(...): Configures the model for training.
- train_step(data): Defines a single training step for the model.
- test_step(data): Defines a single test (or validation) step for the model.
Note:
The model is designed to handle input data as a tuple of (count matrix, clusters). If enabled (via `get_pred`),
it can also take labels for supervised training. Outputs include the reconstructed data and the predictions
based on latent and reconstructed representations.
"""
def __init__(self,
in_shape: tuple,
n_clusters: int=10,
n_latent_dims: int = 2,
layer_units: list = [10,5],
layer_units_classifier:list = [2],
n_pred: int = 10,
last_activation: str = "sigmoid",
post_loc_init_scale: float=0.1,
prior_scale: float=0.25,
kl_weight: float=1e-5,
get_pred = False,
get_recon_cluster = False,
name='ae',
**kwargs):
super(DomainEnhancingAutoencoderClassifier, self).__init__(name=name, **kwargs)
self.in_shape = in_shape
self.n_clusters = n_clusters
self.n_latent_dims = n_latent_dims
self.layer_units = layer_units
self.last_activation = last_activation
self.get_pred = get_pred
self.n_pred = n_pred
self.layer_units_classifier = layer_units_classifier
self.get_recon_cluster = get_recon_cluster
# RE encoder
self.re_encoder = RandomEffectEncoder(n_latent_dims=self.n_latent_dims,
layer_units=self.layer_units,
post_loc_init_scale=post_loc_init_scale,
prior_scale=prior_scale,
kl_weight=kl_weight)
# RE decoder: weights not tied
self.re_decoder = RandomEffectDecoder(in_shape=self.in_shape,
layer_units = self.layer_units,
last_activation = self.last_activation,
post_loc_init_scale=post_loc_init_scale,
prior_scale=prior_scale,
kl_weight=kl_weight)
# The latent classifier returns class predictions in addition to cluster predictions if get_pred =True
self.re_latent_classifier = Classifier(n_clusters=self.n_clusters,layer_units = self.layer_units_classifier,n_pred = self.n_pred, get_pred = self.get_pred)
# get cluster prediction from reconstruction
if self.get_recon_cluster:
self.re_recon_classifier = Classifier(n_clusters=self.n_clusters,layer_units = self.layer_units_classifier, get_pred = False)
def call(self, inputs, training=None):
if len(inputs) != 2:
raise ValueError('Model inputs need to be a tuple of (count matrix, clusters)')
x, z = inputs
# Encode inputs
latent = self.re_encoder((x, z), training=training)
        # Reconstruct the input from the latents
recon = self.re_decoder((latent, z), training=training)
output_dict = {'recon': recon}
# Apply latent classifier
latent_outs = self.re_latent_classifier(latent)
if self.get_pred:
# The latent classifier returns class predictions in addition to cluster predictions if get_pred=True
pred_y, pred_c_latent = latent_outs
output_dict['pred_y'] = pred_y
output_dict['pred_c_latent'] = pred_c_latent
else:
pred_c_latent = latent_outs
output_dict['pred_c_latent'] = pred_c_latent
if self.get_recon_cluster:
# Cluster predictions from reconstructed counts
pred_c_recon = self.re_recon_classifier(recon)
output_dict['pred_c_recon'] = pred_c_recon
return output_dict
def compile(self,
loss_recon=tf.keras.losses.MeanSquaredError(),
loss_multiclass=tf.keras.losses.CategoricalCrossentropy(),
metric_multiclass=tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy'),
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss_recon_weight=1.0,
loss_class_weight=0.01,
loss_latent_cluster_weight=0.001,
loss_recon_cluster_weight=0.001):
super().compile()
self.loss_recon = loss_recon
        # the multiclass loss is shared by all multiclass tasks (cluster prediction, class prediction, etc.)
self.loss_multiclass = loss_multiclass
self.optimizer = optimizer
# loss weights
self.loss_latent_cluster_weight = loss_latent_cluster_weight
self.loss_recon_weight = loss_recon_weight
# Loss trackers (mean loss across all the batches)
self.loss_recon_tracker = tf.keras.metrics.Mean(name='recon_loss')
self.loss_latent_cluster_tracker = tf.keras.metrics.Mean(name='la_clus_loss')
self.loss_kl_tracker = tf.keras.metrics.Mean(name='kld')
self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss')
if self.get_pred:
self.metric_multiclass = metric_multiclass
self.loss_class_weight = loss_class_weight
self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss')
if self.get_recon_cluster:
self.loss_recon_cluster_weight = loss_recon_cluster_weight
self.loss_recon_cluster_tracker = tf.keras.metrics.Mean(name='recon_clus_loss')
@property
def metrics(self):
metrics_list = [self.loss_recon_tracker,
self.loss_latent_cluster_tracker,
self.loss_kl_tracker,
self.loss_total_tracker]
        if self.get_pred:
            metrics_list = metrics_list + [self.loss_class_tracker,
                                           self.metric_multiclass]
        # independent `if` (not elif): both heads can be active at the same time
        if self.get_recon_cluster:
            metrics_list = metrics_list + [self.loss_recon_cluster_tracker]
return metrics_list
    def _compute_update_loss(self, loss_recon, loss_latent_cluster, loss_recon_cluster=None, loss_class=None,
                             training=True):
        '''Compute total loss and update loss running means'''
        # update per-task loss trackers
        if self.get_pred and (loss_class is not None):
            self.loss_class_tracker.update_state(loss_class)
        if self.get_recon_cluster and (loss_recon_cluster is not None):
            self.loss_recon_cluster_tracker.update_state(loss_recon_cluster)
self.loss_recon_tracker.update_state(loss_recon)
self.loss_latent_cluster_tracker.update_state(loss_latent_cluster)
if training:
            # The encoder and decoder contain RandomEffect layers, which inherit from
            # tfp.layers.DenseVariational. That layer adds the KL divergence as a
            # regularization loss, collected in model.losses. Since there is more
            # than one RE layer, we take the mean over each subnetwork's losses.
            kld = tf.reduce_mean(self.re_encoder.losses) + tf.reduce_mean(self.re_decoder.losses)
self.loss_kl_tracker.update_state(kld)
else:
# KLD can't be computed at inference time because posteriors are simplified to
# point estimates
kld = 0
        loss_total = (self.loss_recon_weight * loss_recon) + (self.loss_latent_cluster_weight * loss_latent_cluster) + kld
        if self.get_pred and (loss_class is not None):
            loss_total = loss_total + (self.loss_class_weight * loss_class)
        if self.get_recon_cluster and (loss_recon_cluster is not None):
            loss_total = loss_total + (self.loss_recon_cluster_weight * loss_recon_cluster)
self.loss_total_tracker.update_state(loss_total)
return loss_total
    def train_step(self, data):
        # load data
        x, clusters = data[0]
if self.get_pred:
_, labels = data[1]
sample_weights = None if len(data) != 3 else data[2]
# Train the rest of the model
with tf.GradientTape() as gt:
# Apply RE autoencoder: encoder + decoder
outputs = self((x, clusters), training=True)
recon = outputs['recon']
pred_c_latent = outputs['pred_c_latent']
if self.get_pred:
pred_y = outputs['pred_y']
# Multiclass loss
loss_class = self.loss_multiclass(labels, pred_y)
else:
loss_class = None
if self.get_recon_cluster:
pred_c_recon_1 = outputs['pred_c_recon']
loss_recon_cluster_1 = self.loss_multiclass(clusters, pred_c_recon_1)
else:
loss_recon_cluster_1 = None
# mse loss
loss_recon = self.loss_recon(x, recon)
loss_latent_cluster = self.loss_multiclass(clusters, pred_c_latent)
loss_total = self._compute_update_loss(loss_recon = loss_recon,
loss_latent_cluster = loss_latent_cluster,
loss_recon_cluster = loss_recon_cluster_1,
loss_class = loss_class)
# get trainable variables
lsWeights = self.re_encoder.trainable_variables + self.re_decoder.trainable_variables
            # if loss_latent_cluster_weight > 0, also train the latent classifier's variables
            if self.loss_latent_cluster_weight > 0:
lsWeights = lsWeights + self.re_latent_classifier.trainable_variables
if self.get_recon_cluster:
lsWeights = lsWeights + self.re_recon_classifier.trainable_variables
# backpropagate
grads = gt.gradient(loss_total, lsWeights)
self.optimizer.apply_gradients(zip(grads, lsWeights))
if self.get_pred:
# Update metrics
self.metric_multiclass.update_state(labels, pred_y)
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
#load data
x, clusters = data[0]
if self.get_pred:
_, labels = data[1]
sample_weights = None if len(data) != 3 else data[2]
outputs = self((x, clusters), training=False)
recon = outputs['recon']
pred_c_latent = outputs['pred_c_latent']
if self.get_pred:
pred_y = outputs['pred_y']
# Multiclass loss
loss_class = self.loss_multiclass(labels, pred_y)
else:
loss_class = None
if self.get_recon_cluster:
pred_c_recon_1 = outputs['pred_c_recon']
loss_recon_cluster_1 = self.loss_multiclass(clusters, pred_c_recon_1)
else:
loss_recon_cluster_1 = None
loss_recon = self.loss_recon(x, recon)
loss_latent_cluster = self.loss_multiclass(clusters, pred_c_latent)
loss_total = self._compute_update_loss(loss_recon = loss_recon,
loss_latent_cluster = loss_latent_cluster,
loss_recon_cluster = loss_recon_cluster_1,
loss_class = loss_class, training=False)
if self.get_pred:
# Update metrics
self.metric_multiclass.update_state(labels, pred_y)
return {m.name: m.result() for m in self.metrics}
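
# Usage sketch for the RE subnetwork (hedged: shapes and the cluster count are
# illustrative assumptions):
#
#   re_model = DomainEnhancingAutoencoderClassifier(in_shape=(16,), n_clusters=4)
#   re_model.compile()
#   x = tf.random.uniform((128, 16))
#   z = tf.one_hot(tf.random.uniform((128,), maxval=4, dtype=tf.int32), depth=4)
#   outs = re_model((x, z))                          # {'recon': ..., 'pred_c_latent': ...}
#   re_model.fit((x, z), epochs=2, batch_size=32)
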
class MixedEffectsEncoder(tf.keras.layers.Layer):
"""
A TensorFlow Keras layer that concatenates fixed effects (FE) and random effects (RE) latent spaces,
and optionally applies a random effects (RE) layer. The resulting mixed effects latent space is
then processed through a series of dense hidden layers.
Parameters:
- n_latent_dims (int): The number of dimensions in the mixed effects latent space.
- layer_units (list of int): The number of units in each dense hidden layer.
- post_loc_init_scale (float): Initial scale for the location in the post-RE layer,
used if an RE layer is added.
- prior_scale (float): Scale of the prior in the RE layer, used if an RE layer is added.
- kl_weight (float): Weight of the KL divergence in the loss, used if an RE layer is added.
- add_re_2_meclass (bool): Determines whether to add an RE layer to the Mixed Effects Classifier.
- name (str): Name of the layer.
- **kwargs: Additional keyword arguments for the base Layer class.
This encoder first concatenates the FE and RE latent spaces. It then processes the concatenated latent
space through a series of dense hidden layers defined in `layer_units`. If `add_re_2_meclass` is True,
an RE layer is applied after the dense hidden layers. The output is a mixed effects latent space that
can be used for further processing or classification.
The `call` method:
Takes inputs `fe_latent`, `re_latent`, and `z`, and processes them through the encoder to produce the
mixed effects latent space. If `add_re_2_meclass` is True, `z` is used in the RE layer.
Inputs:
- fe_latent: The latent representation of the fixed effects.
- re_latent: The latent representation of the random effects.
- z: Additional features or information, used if an RE layer is added.
Returns:
- me_latent: The resulting mixed effects latent space.
"""
def __init__(self,
n_latent_dims: int = 2,
layer_units: list = [10,5],
post_loc_init_scale: float=0.1,
prior_scale: float=0.25,
kl_weight: float=1e-5,
add_re_2_meclass = False,
name='me_encoder',
**kwargs):
super(MixedEffectsEncoder, self).__init__(name=name, **kwargs)
self.n_latent_dims = n_latent_dims
self.layer_units = layer_units
        # optionally add an RE layer to the ME classifier
        self.add_re_2_meclass = add_re_2_meclass
        self.concat2subnets = tf.keras.layers.Concatenate(axis=-1, name=name + '_concat_fe_re_latent')
# define hidden layers
self.dense_hidden_layers = {}
for i, n_units in enumerate(self.layer_units):
key_name = "dense_" + str(i)
self.dense_hidden_layers[key_name] = Dense(units=n_units, activation="selu", name=key_name)
if self.add_re_2_meclass:
self.re_layer = ClusterScaleBiasBlock(self.n_latent_dims, post_loc_init_scale = post_loc_init_scale,
prior_scale = prior_scale,
kl_weight = kl_weight,
name = name + '_re_layer')
self.act = Activation('selu')
self.dense_me_latent = Dense(units=self.n_latent_dims, activation="selu", name="dense_me_latent")
    def call(self, inputs, training=None):
        # inputs is a dict with keys "fe_latent", optionally "re_latent", and "z"
        # (z is required only when add_re_2_meclass is True)
        fe_latent = inputs.get("fe_latent")
        re_latent = inputs.get("re_latent", None)
        # only concatenate re_latent if it is provided
        if re_latent is not None:
            x = self.concat2subnets([fe_latent, re_latent])
        else:
            x = fe_latent
        if self.add_re_2_meclass:
            z = inputs["z"]
# apply hidden layers
for key, layer in self.dense_hidden_layers.items():
x = layer(x)
# Optional, add re layer
if self.add_re_2_meclass:
x = self.re_layer((x, z), training=training)
x = self.act(x)
me_latent = self.dense_me_latent(x)
return me_latent
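
# Usage sketch for MixedEffectsEncoder (hedged: latent sizes are assumptions).
# Inputs are passed as a dict; "re_latent" and "z" are optional:
#
#   me_enc = MixedEffectsEncoder(n_latent_dims=2, layer_units=[10, 5])
#   me_latent = me_enc({"fe_latent": tf.zeros((32, 2)),
#                       "re_latent": tf.zeros((32, 2))})       # -> (32, 2)
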
class MixedEffectsModel(tf.keras.Model):
"""
MixedEffectsModel. It is a mixed effects classifier which processes inputs through a Mixed Effects Encoder
and a dense output layer for classification. It's designed to handle both fixed effects (FE)
and random effects (RE) latent spaces, making it suitable for scenarios where
both fixed and random effects are considered.
Parameters:
- n_latent_dims (int): The number of dimensions in the mixed effects latent space
created by the Mixed Effects Encoder.
- layer_units (list of int): The number of units in each dense hidden layer within
the Mixed Effects Encoder.
- n_pred (int): The number of units in the final dense output layer, typically corresponding
to the number of classes for classification.
- post_loc_init_scale (float): Initial scale for the location in the post-RE layer within
the Mixed Effects Encoder, used if an RE layer is added.
- prior_scale (float): Scale of the prior in the RE layer within the Mixed Effects Encoder,
used if an RE layer is added.
- kl_weight (float): Weight of the KL divergence in the loss within the Mixed Effects Encoder,
used if an RE layer is added.
- add_re_2_meclass (bool): Determines whether to add an RE layer to the Mixed Effects Classifier
within the Mixed Effects Encoder.
- name (str): Name of the model.
- **kwargs: Additional keyword arguments for the base Model class.
The model encapsulates a Mixed Effects Encoder for processing the FE and RE latent spaces,
followed by a dense output layer with softmax activation for classification.
The `call` method:
Processes the inputs through the Mixed Effects Encoder and then through the dense output layer.
Inputs:
- fe_latent: The latent representation of the fixed effects.
- re_latent: The latent representation of the random effects.
- z: Additional features or information, used in the RE layer of the Mixed Effects Encoder
if `add_re_2_meclass` is True.
Returns:
- y: The classification output, with probabilities for each class.
"""
def __init__(self,
n_latent_dims: int = 2,
layer_units: list = [10,5],
n_pred: int = 10,
post_loc_init_scale: float=0.1,
prior_scale: float=0.25,
kl_weight: float=1e-5,
add_re_2_meclass = False,
name='mec',
**kwargs):
        super(MixedEffectsModel, self).__init__(name=name, **kwargs)
self.n_latent_dims = n_latent_dims
self.layer_units = layer_units
self.n_pred = n_pred
# MixedEffectsmodule
self.encoder = MixedEffectsEncoder(n_latent_dims = self.n_latent_dims,
layer_units = self.layer_units,
post_loc_init_scale = post_loc_init_scale,
prior_scale = prior_scale,
kl_weight = kl_weight,
add_re_2_meclass = add_re_2_meclass,
name = 'me_encoder',
**kwargs)
self.dense_out = Dense(self.n_pred, activation='softmax', name=name + '_out')
    def call(self, inputs, training=None):
        # inputs: dict with "fe_latent" and, optionally, "re_latent" and "z" (see MixedEffectsEncoder)
me_latent = self.encoder(inputs,training=training)
# dense out with softmax activation
y = self.dense_out(me_latent)
return y
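
# Usage sketch for MixedEffectsModel (hedged: dimensions and class count are
# illustrative assumptions):
#
#   mec = MixedEffectsModel(n_latent_dims=2, layer_units=[10, 5], n_pred=10)
#   y = mec({"fe_latent": tf.zeros((32, 2)),
#            "re_latent": tf.zeros((32, 2))})        # -> (32, 10) class probabilities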