''' Autoencoder-classifiers, including domain adversarial and mixed effects variations. ''' import numpy as np import tensorflow as tf import tensorflow.keras.layers as tkl from tensorflow.python.framework import tensor_shape from tensorflow.python.keras.engine.input_spec import InputSpec from .random_effects import ClusterScaleBiasBlock, RandomEffects class TiedConv2DTranspose(tkl.Conv2DTranspose): def __init__(self, source_layer: tkl.Conv2D, filters, kernel_size, strides=(1, 1), padding='valid', output_padding=None, data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, **kwargs): """Conv2DTranspose layer that shares weights with a given Conv2D layer. (The bias tensor is not shared as the dimensionality of the output is inherently different.) Args: source_layer (Conv2D): Conv2D layer with which to share weights all other arguments same as original Conv2DTranspose """ self.source_layer = source_layer super().__init__(filters, kernel_size, strides=strides, padding=padding, output_padding=output_padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint, **kwargs) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if len(input_shape) != 4: raise ValueError('Inputs should have rank 4. Received input ' 'shape: ' + str(input_shape)) channel_axis = self._get_channel_axis() if input_shape.dims[channel_axis].value is None: raise ValueError('The channel dimension of the inputs ' 'should be defined. Found `None`.') input_dim = int(input_shape[channel_axis]) self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim}) # kernel_shape = self.kernel_size + (self.filters, input_dim) # Link to weights from the source conv layer self.kernel = self.source_layer.weights[0] if self.use_bias: self.bias = self.add_weight(name='bias', shape=(self.filters,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True, dtype=self.dtype) else: self.bias = None self.built = True class Encoder(tkl.Layer): def __init__(self, n_latent_dims: int=56, layer_filters: list=[64, 64, 64, 128, 256, 512], return_layer_activations: bool=False, name='encoder', **kwargs): """Transforms 2D image into a compressed vector representation. Contains 6x 2D strided convolutional layers. Args: n_latent_dims (int, optional): Size of compressed representation output. Defaults to 56. layer_filters (list, optional): Filters per convolutional layer. Defaults to [64, 128, 256, 512, 1024, 1024]. return_layer_activations (bool, optional): Whether to return every layer's output. Defaults to False. name (str, optional): Name. Defaults to 'encoder'. """ super(Encoder, self).__init__(name=name, **kwargs) self.n_latent_dims = n_latent_dims self.layer_filters = layer_filters self.return_layer_activations = return_layer_activations self.conv0 = tkl.Conv2D(layer_filters[0], 4, padding='same', name=name + '_conv0') self.bn0 = tkl.BatchNormalization(name=name+ '_bn0') self.prelu0 = tkl.PReLU(name=name + '_prelu0') self.conv1 = tkl.Conv2D(layer_filters[1], 4, padding='same', name=name + '_conv1') self.bn1 = tkl.BatchNormalization(name=name+ '_bn1') self.prelu1 = tkl.PReLU(name=name + '_prelu1') self.conv2 = tkl.Conv2D(layer_filters[2], 4, padding='same', name=name + '_conv2') self.bn2 = tkl.BatchNormalization(name=name+ '_bn2') self.prelu2 = tkl.PReLU(name=name + '_prelu2') self.conv3 = tkl.Conv2D(layer_filters[3], 4, strides=(2, 2), padding='same', name=name + '_conv3') self.bn3 = tkl.BatchNormalization(name=name+ '_bn3') self.prelu3 = tkl.PReLU(name=name + '_prelu3') self.conv4 = tkl.Conv2D(layer_filters[4], 4, strides=(2, 2), padding='same', name=name + '_conv4') self.bn4 = tkl.BatchNormalization(name=name+ '_bn4') self.prelu4 = tkl.PReLU(name=name + '_prelu4') self.conv5 = tkl.Conv2D(layer_filters[5], 4, strides=(2, 2), padding='same', name=name + '_conv5') self.bn5 = tkl.BatchNormalization(name=name+ '_bn5') self.prelu5 = tkl.PReLU(name=name + '_prelu5') self.flatten = tkl.Flatten(name=name + '_flatten') self.dense = tkl.Dense(n_latent_dims, name=name + '_latent') self.bn_out = tkl.BatchNormalization(name=name + '_output') def call(self, inputs, training=None): x0 = self.conv0(inputs) x0 = self.bn0(x0, training=training) x0 = self.prelu0(x0) x1 = self.conv1(x0) x1 = self.bn1(x1, training=training) x1 = self.prelu1(x1) x2 = self.conv2(x1) x2 = self.bn2(x2, training=training) x2 = self.prelu2(x2) x3 = self.conv3(x2) x3 = self.bn3(x3, training=training) x3 = self.prelu3(x3) x4 = self.conv4(x3) x4 = self.bn4(x4, training=training) x4 = self.prelu4(x4) x5 = self.conv5(x4) x5 = self.bn5(x5, training=training) x5 = self.prelu5(x5) latent = self.flatten(x5) latent = self.dense(latent) latent = self.bn_out(latent) if self.return_layer_activations: return x0, x1, x2, x3, x4, x5, latent else: return latent def get_config(self): return {'n_latent_dims': self.n_latent_dims, 'layer_filters': self.layer_filters, 'return_layer_activations': self.return_layer_activations} class Decoder(tkl.Layer): def __init__(self, image_shape: tuple=(256, 256, 1), layer_filters: list=[512, 256, 128, 64, 64, 64], name='decoder', **kwargs): """Transforms compressed vector representation back into a 2D image. Contains 6x 2D transposed convolutional layers. Args: image_shape (tuple, optional): Output image shape. Defaults to (256, 256, 1). layer_filters (list, optional): Number of filters in each convolutional layer. This should be the reverse of the layer_filters argument given to the encoder. Defaults to [1024, 1024, 512, 256, 128, 64]. name (str, optional): Name. Defaults to 'decoder'. """ super(Decoder, self).__init__(name=name, **kwargs) self.image_shape = image_shape self.layer_filters = layer_filters tupReshape = (image_shape[0] // 8, image_shape[1] // 8, layer_filters[0]) self.dense = tkl.Dense(np.product(tupReshape), name=name + '_dense') self.reshape = tkl.Reshape(tupReshape, name=name + '_reshape') self.prelu_dense = tkl.PReLU(name=name + '_prelu_dense') self.tconv0 = tkl.Conv2DTranspose(layer_filters[1], 4, strides=(2, 2), padding='same', name=name + '_tconv0') self.bn0 = tkl.BatchNormalization(name=name+ '_bn0') self.prelu0 = tkl.PReLU(name=name + '_prelu0') self.tconv1 = tkl.Conv2DTranspose(layer_filters[2], 4, strides=(2, 2), padding='same', name=name + '_tconv1') self.bn1 = tkl.BatchNormalization(name=name+ '_bn1') self.prelu1 = tkl.PReLU(name=name + '_prelu1') self.tconv2 = tkl.Conv2DTranspose(layer_filters[3], 4, strides=(2, 2), padding='same', name=name + '_tconv2') self.bn2 = tkl.BatchNormalization(name=name+ '_bn2') self.prelu2 = tkl.PReLU(name=name + '_prelu2') self.tconv3 = tkl.Conv2DTranspose(layer_filters[4], 4, padding='same', name=name + '_tconv3') self.bn3 = tkl.BatchNormalization(name=name+ '_bn3') self.prelu3 = tkl.PReLU(name=name + '_prelu3') self.tconv4 = tkl.Conv2DTranspose(layer_filters[5], 4, padding='same', name=name + '_tconv4') self.bn4 = tkl.BatchNormalization(name=name+ '_bn4') self.prelu4 = tkl.PReLU(name=name + '_prelu4') self.tconv5 = tkl.Conv2DTranspose(1, 4, padding='same', name=name + '_tconv5') self.bn5 = tkl.BatchNormalization(name=name+ '_bn5') self.sigmoid_out = tkl.Activation('sigmoid', name=name + '_sigmoid') def call(self, inputs, training=None): x = self.dense(inputs) x = self.reshape(x) x = self.prelu_dense(x) x = self.tconv0(x) x = self.bn0(x, training=training) x = self.prelu0(x) x = self.tconv1(x) x = self.bn1(x, training=training) x = self.prelu1(x) x = self.tconv2(x) x = self.bn2(x, training=training) x = self.prelu2(x) x = self.tconv3(x) x = self.bn3(x, training=training) x = self.prelu3(x) x = self.tconv4(x) x = self.bn4(x, training=training) x = self.prelu4(x) x = self.tconv5(x) x = self.bn5(x, training=training) x = self.sigmoid_out(x) return x def get_config(self): return {'image_shape': self.image_shape, 'layer_filters': self.layer_filters} class TiedDecoder(Decoder): def __init__(self, encoder_layers: list, image_shape: tuple=(256, 256, 1), layer_filters: list=[512, 256, 128, 64, 64, 64], name='decoder', **kwargs): """Transforms compressed vector representation back into a 2D image. Contains 6x 2D transposed convolutional layers, and filter weights are tied to a given encoder. Args: encoder_layers (list): List of encoder layers whose weights will be shared with this decoder. image_shape (tuple, optional): Output image shape. Defaults to (256, 256, 1). layer_filters (list, optional): Number of filters in each convolutional layer. This should be the reverse of the layer_filters argument given to the encoder. Defaults to [1024, 1024, 512, 256, 128, 64]. name (str, optional): Name. Defaults to 'decoder'. """ super(TiedDecoder, self).__init__(image_shape=image_shape, layer_filters=layer_filters, name=name, **kwargs) # Replace conventional Conv2DTranspose layers with ones that share weights self.tconv0 = TiedConv2DTranspose(encoder_layers[-1], layer_filters[1], 4, strides=(2, 2), padding='same', name=name + '_tconv0') self.tconv1 = TiedConv2DTranspose(encoder_layers[-2], layer_filters[2], 4, strides=(2, 2), padding='same', name=name + '_tconv1') self.tconv2 = TiedConv2DTranspose(encoder_layers[-3], layer_filters[3], 4, strides=(2, 2), padding='same', name=name + '_tconv2') self.tconv3 = TiedConv2DTranspose(encoder_layers[-4], layer_filters[4], 4, padding='same', name=name + '_tconv3') self.tconv4 = TiedConv2DTranspose(encoder_layers[-5], layer_filters[5], 4, padding='same', name=name + '_tconv4') self.tconv5 = TiedConv2DTranspose(encoder_layers[-6], 1, 4, padding='same', name=name + '_tconv5') class AuxClassifier(tkl.Layer): def __init__(self, units: int=32, name='auxclassifier', **kwargs): """Simple dense binary classifier with one hidden layer and sigmoid output. Intended to be attached to the autoencoder to perform classification. Args: units (int, optional): Number of hidden layer neurons. Defaults to 32. name (str, optional): Name. Defaults to 'auxclassifier'. """ super(AuxClassifier, self).__init__(name=name, **kwargs) self.units = units self.hidden = tkl.Dense(units, name=name + '_dense') self.activation = tkl.LeakyReLU(name=name + '_leakyrelu') self.dense_out = tkl.Dense(1, activation='sigmoid', name=name + '_output') def call(self, inputs): x = self.hidden(inputs) x = self.activation(x) x = self.dense_out(x) return x def get_config(self): return {'units': self.units} class BaseAutoencoderClassifier(tf.keras.Model): def __init__(self, image_shape: tuple=(256, 256, 1), n_latent_dims: int=56, encoder_layer_filters: list=[64, 64, 64, 128, 256, 512], classifier_hidden_units: int=32, name='autoencoder', **kwargs): """Basic autoencoder with auxiliary classifier to predict a binary label from the latent representation. Args: image_shape (tuple, optional): Input image shape. Defaults to (256, 256, 1). n_latent_dims (int, optional): Size of latent representation. Defaults to 56. encoder_layer_filters (list, optional): Number of filters per encoder layer. Defaults to [64, 128, 256, 512, 1024, 1024]. classifier_hidden_units (int, optional): Number of hidden layer neurons in the auxiliary classifier. Defaults to 32. name (str, optional): Name. Defaults to 'autoencoder'. """ super(BaseAutoencoderClassifier, self).__init__(name=name, **kwargs) self.image_shape = image_shape self.n_latent_dims = n_latent_dims self.encoder_layer_filters = encoder_layer_filters self.decoder_layer_filters = encoder_layer_filters[-1::-1] self.classifier_hidden_units = classifier_hidden_units self.encoder = Encoder(n_latent_dims=n_latent_dims, layer_filters=encoder_layer_filters, return_layer_activations=False) lsEncoderLayers = [self.encoder.conv0, self.encoder.conv1, self.encoder.conv2, self.encoder.conv3, self.encoder.conv4, self.encoder.conv5] self.decoder = TiedDecoder(lsEncoderLayers, image_shape=image_shape, layer_filters=self.decoder_layer_filters) self.classifier = AuxClassifier(units=classifier_hidden_units) def call(self, inputs, training=None): latent = self.encoder(inputs, training=training) recon = self.decoder(latent, training=training) classification = self.classifier(latent) return recon, classification class AdversarialClassifier(tkl.Layer): def __init__(self, image_shape: tuple, n_clusters: int, layer_filters: list=[16, 32, 32, 64, 64, 128, 128, 256], dense_units: int=512, name='adversary', **kwargs): """Domain adversarial classifier for predicting a sample's cluster from the layer outputs of the Encoder. Args: image_shape (tuple): Original image shape. n_clusters (int): Number of possible clusters (domains), i.e. the size of the softmax output. layer_filters (list, optional): Number of filters in each adversary layer. Defaults to [16, 32, 64, 128, 256, 512]. dense_units (int, optional): Number of neurons in adversary dense layer. Defaults to 512. name (str, optional): Name. Defaults to 'adversary'. """ super(AdversarialClassifier, self).__init__(name=name, **kwargs) if image_shape[:2] != (256, 256): raise ValueError('Only 256x256 images are supported at this time.') self.image_shape = image_shape self.n_clusters = n_clusters self.layer_filters = layer_filters self.dense_units = dense_units self.conv0 = tkl.Conv2D(layer_filters[0], 4, strides=(2, 2), padding='same', name=name + '_conv0') self.bn0 = tkl.BatchNormalization(name=name + '_bn0') self.prelu0 = tkl.PReLU(name=name + '_prelu0') self.concat1 = tkl.Concatenate(axis=-1, name=name + '_concat1') self.conv1 = tkl.Conv2D(layer_filters[1], 4, strides=(2, 2), padding='same', name=name + '_conv1') self.bn1 = tkl.BatchNormalization(name=name + '_bn1') self.prelu1 = tkl.PReLU(name=name + '_prelu1') self.concat2 = tkl.Concatenate(axis=-1, name=name + '_concat2') self.conv2 = tkl.Conv2D(layer_filters[2], 4, strides=(2, 2), padding='same', name=name + '_conv2') self.bn2 = tkl.BatchNormalization(name=name + '_bn2') self.prelu2 = tkl.PReLU(name=name + '_prelu2') self.concat3 = tkl.Concatenate(axis=-1, name=name + '_concat3') self.conv3 = tkl.Conv2D(layer_filters[3], 4, strides=(2, 2), padding='same', name=name + '_conv3') self.bn3 = tkl.BatchNormalization(name=name + '_bn3') self.prelu3 = tkl.PReLU(name=name + '_prelu3') self.conv4 = tkl.Conv2D(layer_filters[4], 4, strides=(2, 2), padding='same', name=name + '_conv4') self.bn4 = tkl.BatchNormalization(name=name + '_bn4') self.prelu4 = tkl.PReLU(name=name + '_prelu4') self.conv5 = tkl.Conv2D(layer_filters[5], 4, strides=(2, 2), padding='same', name=name + '_conv5') self.bn5 = tkl.BatchNormalization(name=name + '_bn5') self.prelu5 = tkl.PReLU(name=name + '_prelu5') self.flatten = tkl.Flatten(name=name + '_flatten') self.dense = tkl.Dense(units=self.dense_units, name=name + '_dense') self.softmax = tkl.Dense(units=n_clusters, activation='softmax', name=name + '_softmax') def call(self, inputs): act0, act1, act2, act3, act4, act5, latents = inputs # Layer outputs from the encoder are fed into this model at the point # where their shape matches the data shape. # First 3 encoder layers don't use strided conv and have the same output # shape as the original image. Concatenate them and feed them into the # first layer at the same time. x = tf.concat([act0, act1, act2], axis=-1) # -> D x D x = self.conv0(x) # -> D/2 x D/2 x = self.bn0(x) x = self.prelu0(x) x = self.concat1([x, act3]) x = self.conv1(x) # -> D/4 x D/4 x = self.bn1(x) x = self.prelu1(x) x = self.concat2([x, act4]) x = self.conv2(x) # -> D/8 x D/8 x = self.bn2(x) x = self.prelu2(x) x = self.concat3([x, act5]) x = self.conv3(x) # -> D/16 x D/16 x = self.bn3(x) x = self.prelu3(x) x = self.conv4(x) # -> D/32 x D/32 x = self.bn4(x) x = self.prelu4(x) x = self.conv5(x) # -> D/64 x D/64 x = self.bn5(x) x = self.prelu5(x) x = self.flatten(x) x = tf.concat([x, latents], axis=-1) x = self.dense(x) x = self.softmax(x) return x def get_config(self): return {'image_shape': self.image_shape, 'n_clusters': self.n_clusters, 'layer_filters': self.layer_filters, 'dense_units': self.dense_units} class DomainAdversarialAEC(BaseAutoencoderClassifier): def __init__(self, image_shape: tuple=(256, 256, 1), n_clusters: int=10, n_latent_dims: int=56, encoder_layer_filters: list=[64, 64, 64, 128, 256, 512], classifier_hidden_units: int=32, adversary_layer_filters: list=[16, 32, 32, 64, 64, 128, 128, 256], name='autoencoder', **kwargs ): """Domain adversarial autoencoder-classifier. Adds an adversarial classifier to predict the cluster/domain membership of each sample based on the encoder's intermediate outputs. This compels the encoder to learn features unassociated with cluster characteristics. Args: image_shape (tuple, optional): Image shape. Defaults to (256, 256, 1). n_clusters (int, optional): Number of clusters. Defaults to 10. n_latent_dims (int, optional): Size of latent representation. Defaults to 56. encoder_layer_filters (list, optional): Number of filters in each encoder layer. Defaults to [64, 64, 64, 128, 256, 512]. classifier_hidden_units (int, optional): Number of neurons in auxiliary classifier hidden layer. Defaults to 32. adversary_layer_filters (list, optional): Number of filters in each adversary layer. Defaults to [16, 32, 32, 64, 64, 128, 128, 256]. name (str, optional): Model name. Defaults to 'autoencoder'. """ super(BaseAutoencoderClassifier, self).__init__( name=name, **kwargs) self.image_shape = image_shape self.n_latent_dims = n_latent_dims self.encoder_layer_filters = encoder_layer_filters self.decoder_layer_filters = encoder_layer_filters[-1::-1] self.classifier_hidden_units = classifier_hidden_units self.encoder = Encoder(n_latent_dims=n_latent_dims, layer_filters=encoder_layer_filters, return_layer_activations=True) lsEncoderLayers = [self.encoder.conv0, self.encoder.conv1, self.encoder.conv2, self.encoder.conv3, self.encoder.conv4, self.encoder.conv5] self.decoder = TiedDecoder(lsEncoderLayers, image_shape=image_shape, layer_filters=self.decoder_layer_filters) self.classifier = AuxClassifier(units=classifier_hidden_units) self.adversary = AdversarialClassifier(image_shape, n_clusters, layer_filters=adversary_layer_filters) def call(self, inputs, training=None): images, clusters = inputs # Call encoder and get layer outputs encoder_outs = self.encoder(images, training=training) latent = encoder_outs[-1] # Reconstruct image from latents recon = self.decoder(latent, training=training) # Classify image from latents classification = self.classifier(latent) # Predict cluster from encoder layer outputs pred_cluster = self.adversary(encoder_outs) return (recon, classification, pred_cluster) def compile(self, loss_recon=tf.keras.losses.MeanSquaredError(), loss_class=tf.keras.losses.BinaryCrossentropy(), loss_adv=tf.keras.losses.CategoricalCrossentropy(), metric_class=tf.keras.metrics.AUC(name='auroc'), metric_adv=tf.keras.metrics.CategoricalAccuracy(name='acc'), opt_autoencoder=tf.keras.optimizers.Adam(lr=0.0001), opt_adversary=tf.keras.optimizers.Adam(lr=0.0001), loss_recon_weight=1.0, loss_class_weight=0.01, loss_gen_weight=0.05, ): """Compile model with given losses, metrics, and optimizers. The main autoencoder-classifier is trained with this loss function: loss_recon_weight * loss_recon (image reconstruction loss) + loss_class_weight * loss_class (phenotype classification loss) - loss_gen_weight * loss_adv (generalization (adversarial) loss) While the adversarial classifier is trained with loss_adv. Args: loss_recon (loss, optional): Image reconstruction loss. Defaults to tf.keras.losses.MeanSquaredError(). loss_class (loss, optional): Auxiliary classification loss. Defaults to tf.keras.losses.BinaryCrossentropy(). loss_adv (loss, optional): Adversarial classification loss. Defaults to tf.keras.losses.CategoricalCrossentropy(). metric_class (metric, optional): Auxiliary classification metric. Defaults to tf.keras.metrics.AUC(name='auroc'). metric_adv (metric, optional): Adversarial classification metric. Defaults to tf.keras.metrics.CategoricalAccuracy(name='acc'). opt_autoencoder (optimizer, optional): Optimizer for the main model. Defaults to tf.keras.optimizers.Adam(lr=0.0001). opt_adversary (optimizer, optional): Optimizer for the adversarial classifier. Defaults to tf.keras.optimizers.Adam(lr=0.0001). loss_recon_weight (float, optional): Weight for reconstruction loss. Defaults to 1.0. loss_class_weight (float, optional): Weight for auxiliary classification loss. Defaults to 0.01. loss_gen_weight (float, optional): Weight for generalization loss. Defaults to 0.05. """ super().compile() self.loss_recon = loss_recon self.loss_class = loss_class self.loss_adv = loss_adv self.opt_autoencoder = opt_autoencoder self.opt_adversary = opt_adversary # Loss trackers to maintain a running mean of each loss self.loss_recon_tracker = tf.keras.metrics.Mean(name='recon_loss') self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss') self.loss_adv_tracker = tf.keras.metrics.Mean(name='adv_loss') self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss') self.metric_class = metric_class self.metric_adv = metric_adv self.loss_recon_weight = loss_recon_weight self.loss_class_weight = loss_class_weight self.loss_gen_weight = loss_gen_weight @property def metrics(self): return [self.loss_recon_tracker, self.loss_class_tracker, self.loss_adv_tracker, self.loss_total_tracker, self.metric_class, self.metric_adv] def train_step(self, data): if len(data) == 3: (images, clusters), (_, labels), sample_weights = data else: (images, clusters), (_, labels) = data sample_weights = None # Train adversary encoder_outs = self.encoder(images, training=True) with tf.GradientTape() as gt: pred_cluster = self.adversary(encoder_outs) loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights) grads_adv = gt.gradient(loss_adv, self.adversary.trainable_variables) self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_variables)) # Update adversarial loss tracker self.metric_adv.update_state(clusters, pred_cluster) self.loss_adv_tracker.update_state(loss_adv) # Train autoencoder with tf.GradientTape(persistent=True) as gt2: pred_recon, pred_class, pred_cluster = self((images, clusters), training=True) loss_class = self.loss_class(labels, pred_class, sample_weight=sample_weights) loss_recon = self.loss_recon(images, pred_recon, sample_weight=sample_weights) loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights) total_loss = (self.loss_recon_weight * loss_recon) \ + (self.loss_class_weight * loss_class) \ - (self.loss_gen_weight * loss_adv) lsWeights = self.encoder.trainable_variables + self.decoder.trainable_variables \ + self.classifier.trainable_variables grads_aec = gt2.gradient(total_loss, lsWeights) self.opt_autoencoder.apply_gradients(zip(grads_aec, lsWeights)) # Update loss trackers self.metric_class.update_state(labels, pred_class) self.loss_class_tracker.update_state(loss_class) self.loss_recon_tracker.update_state(loss_recon) self.loss_total_tracker.update_state(total_loss) return {m.name: m.result() for m in self.metrics} def test_step(self, data): (images, clusters), (_, labels) = data pred_recon, pred_class, pred_cluster = self((images, clusters), training=False) loss_class = self.loss_class(labels, pred_class) loss_recon = self.loss_recon(images, pred_recon) loss_adv = self.loss_adv(clusters, pred_cluster) total_loss = (self.loss_recon_weight * loss_recon) \ + (self.loss_class_weight * loss_class) \ - (self.loss_gen_weight * loss_adv) self.metric_class.update_state(labels, pred_class) self.metric_adv.update_state(clusters, pred_cluster) self.loss_class_tracker.update_state(loss_class) self.loss_recon_tracker.update_state(loss_recon) self.loss_adv_tracker.update_state(loss_adv) self.loss_total_tracker.update_state(total_loss) return {m.name: m.result() for m in self.metrics} def load_weights_base_aec(model, weights_path: str): """Loads weights into BaseAutoencoderClassifier. When using the native Keras model.load_weights(), it fails to match up the saved weights with the model layers. This workaround manually matches them up by name. Args: model (tf.keras.Model): model weights_path (str): path to weights """ import h5py h5Weights = h5py.File(weights_path, 'r') for strSubModel in h5Weights.keys(): submodel = model.get_layer(strSubModel) lsWeightsToSet = [] for weights in submodel.weights: strWeightName = weights.name.split('autoencoder/')[1] def match_weight(name): if strWeightName in name: return name # Recursively search for the weight name in the H5 object w = h5Weights[strSubModel].visit(match_weight) lsWeightsToSet += [h5Weights[strSubModel][w]] submodel.set_weights(lsWeightsToSet) class LatentClassifier(tkl.Layer): def __init__(self, n_clusters: int, name='latent_classifier', **kwargs): """Dense classifier to predict cluster from a vector latent representation. Args: n_clusters (int): Number of clusters. name (str, optional): Name. Defaults to 'latent_classifier'. """ super(LatentClassifier, self).__init__(name=name, **kwargs) self.n_clusters = n_clusters self.dense_target = tkl.Dense(32, activation='relu', name=name + '_dense_target') self.dense_cluster0 = tkl.Dense(32, activation='relu', name=name + '_dense_cluster0') self.dense_cluster1 = tkl.Dense(16, activation='relu', name=name + '_dense_cluster1') self.out_cluster = tkl.Dense(n_clusters, activation='softmax', name=name + '_cluster') self.out_target = tkl.Dense(1, activation='sigmoid', name=name + '_target') def call(self, inputs, training=None): y = self.dense_target(inputs) y = self.out_target(y) c = self.dense_cluster0(inputs) c = self.dense_cluster1(c) c = self.out_cluster(c) return y, c def get_config(self): return {'n_clusters': self.n_clusters} class ImageClassifier(tkl.Layer): def __init__(self, n_clusters: int, layer_filters=[16, 32, 32, 64, 64, 128, 256], name='discriminator', **kwargs): """Classifier to predict cluster from an image (reconstruction). Args: n_clusters (int): Number of clusters layer_filters (list, optional): Convolutional filters in each of 7 layers. Defaults to [16, 32, 32, 64, 64, 128, 256]. name (str, optional): Name. Defaults to 'discriminator'. """ super(ImageClassifier, self).__init__(name=name, **kwargs) self.n_clusters = n_clusters self.layer_filters = layer_filters self.blocks = [] for iLayer, nFilters in enumerate(layer_filters): conv = tkl.Conv2D(nFilters, 4, strides=(2, 2), padding='same', name=name + f'_conv{iLayer}') bn = tkl.BatchNormalization(name=name + f'_bn{iLayer}') prelu = tkl.PReLU(name=name + f'_prelu{iLayer}') self.blocks += [(conv, bn, prelu)] # At this point, the data dimension should be 2 x 2 self.cluster_head = [ tkl.Conv2D(n_clusters, 2, padding='valid', name=name + '_convout'), tkl.Flatten(name=name + '_flatten'), tkl.Softmax(name=name + '_softmax') ] def call(self, inputs, training=None): x = inputs for conv, bn, prelu in self.blocks: x = conv(x) x = bn(x, training=training) x = prelu(x) conv, flat, act = self.cluster_head c = conv(x) c = flat(c) pred_cluster = act(c) return pred_cluster def get_config(self): return {'n_clusters': self.n_clusters, 'layer_filters': self.layer_filters} class RandomEffectEncoder(Encoder): def __init__(self, n_latent_dims: int=56, layer_filters: list=[64, 64, 64, 128, 256, 512], post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name='encoder', **kwargs): """Encoder with random effect cluster-specific scales and biases for each convolutional filter. Args: n_latent_dims (int, optional): Dimensionality of latent representation output. Defaults to 56. layer_filters (list, optional): Convolutional filters for each of the 6 layers. Defaults to [64, 64, 64, 128, 256, 512]. post_loc_init_scale (float, optional): S.d. for random normal initialization of posteriors. Defaults to 0.1. prior_scale (float, optional): S.d. of normal prior distributions. Defaults to 0.25. kl_weight (float, optional): KL Divergence loss weight. Defaults to 1e-5. name (str, optional): Model name. Defaults to 'encoder'. """ super(RandomEffectEncoder, self).__init__(n_latent_dims=n_latent_dims, layer_filters=layer_filters, name=name, **kwargs) # Replace batch norm layers with cluster-specific scale/bias layers self.re0 = ClusterScaleBiasBlock(self.layer_filters[0], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re0') self.re1 = ClusterScaleBiasBlock(self.layer_filters[1], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re1') self.re2 = ClusterScaleBiasBlock(self.layer_filters[2], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re2') self.re3 = ClusterScaleBiasBlock(self.layer_filters[3], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re3') self.re4 = ClusterScaleBiasBlock(self.layer_filters[4], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re4') self.re5 = ClusterScaleBiasBlock(self.layer_filters[5], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re5') del self.bn0, self.bn1, self.bn2, self.bn3, self.bn4, self.bn5, self.bn_out def call(self, inputs, training=None): x, z = inputs x = self.conv0(x) x = self.re0((x, z), training=training) x = self.prelu0(x) x = self.conv1(x) x = self.re1((x, z), training=training) x = self.prelu1(x) x = self.conv2(x) x = self.re2((x, z), training=training) x = self.prelu2(x) x = self.conv3(x) x = self.re3((x, z), training=training) x = self.prelu3(x) x = self.conv4(x) x = self.re4((x, z), training=training) x = self.prelu4(x) x = self.conv5(x) x = self.re5((x, z), training=training) x = self.prelu5(x) x = self.flatten(x) x = self.dense(x) return x class RandomEffectDecoder(Decoder): def __init__(self, image_shape: tuple=(256, 256, 1), layer_filters: list=[512, 256, 128, 64, 64, 64], post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name='decoder', **kwargs): """Decoder with random effect cluster-specific scales and biases for each convolutional filter. Args: image_shape (tuple, optional): Shape of reconstructed image. Defaults to (256, 256, 1). layer_filters (list, optional): Convolutional filters for each of the 6 layers. Defaults to [64, 64, 64, 128, 256, 512]. post_loc_init_scale (float, optional): S.d. for random normal initialization of posteriors. Defaults to 0.1. prior_scale (float, optional): S.d. of normal prior distributions. Defaults to 0.25. kl_weight (float, optional): KL Divergence loss weight. Defaults to 1e-5. name (str, optional): Model name. Defaults to 'encoder'. """ super(RandomEffectDecoder, self).__init__(image_shape=image_shape, layer_filters=layer_filters, name=name, **kwargs) # Replace batch norm layers with cluster-specific scale/bias layers self.re0 = ClusterScaleBiasBlock(self.layer_filters[1], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re0') self.re1 = ClusterScaleBiasBlock(self.layer_filters[2], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re1') self.re2 = ClusterScaleBiasBlock(self.layer_filters[3], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re2') self.re3 = ClusterScaleBiasBlock(self.layer_filters[4], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re3') self.re4 = ClusterScaleBiasBlock(self.layer_filters[5], post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re4') self.re5 = ClusterScaleBiasBlock(1, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name=name + '_re5') del self.bn0, self.bn1, self.bn2, self.bn3, self.bn4, self.bn5 def call(self, inputs, training=None): x, z = inputs x = self.dense(x) x = self.reshape(x) x = self.prelu_dense(x) x = self.tconv0(x) x = self.re0((x, z), training=training) x = self.prelu0(x) x = self.tconv1(x) x = self.re1((x, z), training=training) x = self.prelu1(x) x = self.tconv2(x) x = self.re2((x, z), training=training) x = self.prelu2(x) x = self.tconv3(x) x = self.re3((x, z), training=training) x = self.prelu3(x) x = self.tconv4(x) x = self.re4((x, z), training=training) x = self.prelu4(x) x = self.tconv5(x) x = self.re5((x, z), training=training) x = self.sigmoid_out(x) return x class DomainEnhancingAutoencoderClassifier(tf.keras.Model): def __init__(self, image_shape: tuple=(256, 256, 1), n_clusters: int=10, n_latent_dims: int=56, encoder_layer_filters: list=[64, 64, 64, 128, 256, 512], post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name='autoencoder', **kwargs): """ Autoencoder that emphasizes cluster differences in both the compressed latent representation and the reconstructed image. This is done by adding random effects layers to the encoder and decoder, as well as using additional classifiers to maximize the cluster-predictive information present in the latents and reconstructions. Args: image_shape (tuple, optional): Input image size. Defaults to (256, 256, 1). n_clusters (int, optional): Number of clusters. Defaults to 10. n_latent_dims (int, optional): Dimensionality of latent representations. Defaults to 56. encoder_layer_filters (list, optional): Convolutional filters in each of the 6 encoder layers. Defaults to [64, 64, 64, 128, 256, 512]. post_loc_init_scale (float, optional): S.d. for random normal initialization of posteriors. Defaults to 0.1. prior_scale (float, optional): S.d. of normal prior distributions. Defaults to 0.25. kl_weight (float, optional): KL Divergence loss weight. Defaults to 1e-5. name (str, optional): Model name. Defaults to 'autoencoder'. """ super(DomainEnhancingAutoencoderClassifier, self).__init__(name=name, **kwargs) self.image_shape = image_shape self.n_latent_dims = n_latent_dims self.encoder_layer_filters = encoder_layer_filters # Decoder layers should mirror the encoder layers self.decoder_layer_filters = encoder_layer_filters[-1::-1] self.encoder = RandomEffectEncoder(n_latent_dims=n_latent_dims, layer_filters=encoder_layer_filters, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight) self.decoder = RandomEffectDecoder(layer_filters=self.decoder_layer_filters, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight) # Classifiers to guide the latents and reconstructions to producing outputs # that are laden with cluster-predictive information self.latent_classifier = LatentClassifier(n_clusters=n_clusters) self.image_classifier = ImageClassifier(n_clusters=n_clusters) def call(self, inputs, training=None): if len(inputs) != 2: raise ValueError('Model inputs need to be a tuple of (images, clusters)') x, z = inputs # Call encoder and get latents latent = self.encoder((x, z), training=training) # Predict cluster from latents pred_y, pred_c_latent = self.latent_classifier(latent) # Reconstruct image from latents recon = self.decoder((latent, z), training=training) # Predict cluster from reconstruction pred_c_recon = self.image_classifier(recon) return recon, pred_y, pred_c_latent, pred_c_recon def compile(self, loss_recon=tf.keras.losses.MeanSquaredError(), loss_class=tf.keras.losses.BinaryCrossentropy(), loss_cluster=tf.keras.losses.CategoricalCrossentropy(), metric_class=tf.keras.metrics.AUC(name='auroc'), optimizer=tf.keras.optimizers.Adam(lr=0.0001), loss_class_weight=0.01, loss_latent_cluster_weight=0.001, loss_image_cluster_weight=0.001 ): """Compile model with given losses, metrics, and optimizer. The autoencoder-classifier is trained with this loss function: loss_recon (image reconstruction loss) + loss_class_weight * loss_class (phenotype classification loss) + loss_latent_cluster_weight * loss_cluster(latent classifier) (cluster predictiveness of latents) + loss_image_cluster_weight * loss_cluster(recon classifier) (cluster predictiveness of reconstructions) Args: loss_recon (loss, optional): Image reconstruction loss. Defaults to tf.keras.losses.MeanSquaredError(). loss_class (loss, optional): Auxiliary classification loss. Defaults to tf.keras.losses.BinaryCrossentropy(). loss_cluster (loss, optional): Cluster classification loss. Defaults to tf.keras.losses.CategoricalCrossentropy(). metric_class (metric, optional): Auxiliary classification metric. Defaults to tf.keras.metrics.AUC(name='auroc'). optimizer (optimizer, optional): Optimizer. Defaults to tf.keras.optimizers.Adam(lr=0.0001). loss_class_weight (float, optional): Weight for auxiliary classification loss. Defaults to 0.01. loss_latent_cluster_weight (float, optional): Weight for cluster prediction loss for latents. Defaults to 0.001. loss_image_cluster_weight (float, optional): Weight for cluster prediction loss for recons. Defaults to 0.001. """ super().compile() self.loss_recon = loss_recon self.loss_class = loss_class self.loss_cluster = loss_cluster self.optimizer = optimizer self.metric_class = metric_class self.loss_class_weight = loss_class_weight self.loss_latent_cluster_weight = loss_latent_cluster_weight self.loss_image_cluster_weight = loss_image_cluster_weight # Loss trackers (running means) self.loss_recon_tracker = tf.keras.metrics.Mean(name='recon_loss') self.loss_class_tracker = tf.keras.metrics.Mean(name='class_loss') self.loss_latent_cluster_tracker = tf.keras.metrics.Mean(name='la_clus_loss') self.loss_image_cluster_tracker = tf.keras.metrics.Mean(name='im_clus_loss') self.loss_kl_tracker = tf.keras.metrics.Mean(name='kld') self.loss_total_tracker = tf.keras.metrics.Mean(name='total_loss') @property def metrics(self): return [self.loss_recon_tracker, self.loss_class_tracker, self.metric_class, self.loss_latent_cluster_tracker, self.loss_image_cluster_tracker, self.loss_kl_tracker, self.loss_total_tracker] def _compute_update_loss(self, loss_recon, loss_class, loss_latent_cluster, loss_image_cluster, training=True): '''Compute total loss and update loss running means''' self.loss_recon_tracker.update_state(loss_recon) self.loss_class_tracker.update_state(loss_class) self.loss_latent_cluster_tracker.update_state(loss_latent_cluster) self.loss_image_cluster_tracker.update_state(loss_image_cluster) if training: kld = tf.reduce_mean(self.encoder.losses) + tf.reduce_mean(self.decoder.losses) self.loss_kl_tracker.update_state(kld) else: # KLD can't be computed at inference time because posteriors are simplified to # point estimates kld = 0 loss_total = loss_recon \ + (self.loss_class_weight * loss_class) \ + (self.loss_latent_cluster_weight * loss_latent_cluster) \ + (self.loss_image_cluster_weight * loss_image_cluster) \ + kld self.loss_total_tracker.update_state(loss_total) return loss_total def train_step(self, data): if len(data) == 3: (images, clusters), (_, labels), sample_weights = data else: (images, clusters), (_, labels) = data sample_weights = None # Train the recon image-cluster classifier on real images with tf.GradientTape() as gt: # Predict cluster from real images pred_c_image = self.image_classifier(images) loss_image_cluster = self.loss_cluster(clusters, pred_c_image) grads = gt.gradient(loss_image_cluster, self.image_classifier.trainable_variables) self.optimizer.apply_gradients(zip(grads, self.image_classifier.trainable_variables)) # Train the rest of the model with tf.GradientTape() as gt: recon, pred_y, pred_c_latent, pred_c_recon = self((images, clusters), training=True) loss_recon = self.loss_recon(images, recon) loss_class = self.loss_class(labels, pred_y) loss_latent_cluster = self.loss_cluster(clusters, pred_c_latent) loss_image_cluster = self.loss_cluster(clusters, pred_c_recon) loss_total = self._compute_update_loss(loss_recon, loss_class, loss_latent_cluster, loss_image_cluster) lsWeights = self.encoder.trainable_variables + self.decoder.trainable_variables \ + self.latent_classifier.trainable_variables grads = gt.gradient(loss_total, lsWeights) self.optimizer.apply_gradients(zip(grads, lsWeights)) # Update metrics self.metric_class.update_state(labels, pred_y) return {m.name: m.result() for m in self.metrics} def test_step(self, data): (images, clusters), (_, labels) = data recon, pred_y, pred_c_latent, pred_c_recon = self((images, clusters), training=False) loss_recon = self.loss_recon(images, recon) loss_class = self.loss_class(labels, pred_y) loss_latent_cluster = self.loss_cluster(clusters, pred_c_latent) loss_image_cluster = self.loss_cluster(clusters, pred_c_recon) _ = self._compute_update_loss(loss_recon, loss_class, loss_latent_cluster, loss_image_cluster, training=False) self.metric_class.update_state(labels, pred_y) return {m.name: m.result() for m in self.metrics} class MixedEffectAuxClassifier(tkl.Layer): def __init__(self, units: int=32, post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name='auxclassifier', **kwargs): """Mixed effects dense classifier with one hidden layer and sigmoid output. Args: units (int, optional): Number of hidden layer neurons. Defaults to 32. post_loc_init_scale (float, optional): S.d. for random normal initialization of posteriors. Defaults to 0.1. prior_scale (float, optional): S.d. of normal prior distributions. Defaults to 0.25. kl_weight (float, optional): KL Divergence loss weight. Defaults to 1e-5. name (str, optional): Name. Defaults to 'auxclassifier'. """ super().__init__(name=name, **kwargs) self.units = units self.hidden = tkl.Dense(units, name=name + '_dense') self.activation = tkl.LeakyReLU(name=name + '_leakyrelu') self.re_slopes = RandomEffects(units, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name='re_slopes') self.dense_out = tkl.Dense(1, name=name + '_output') self.re_intercept = RandomEffects(1, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name='re_intercept') self.sigmoid = tkl.Activation('sigmoid', name='sigmoid') def call(self, inputs, training=None): x, z = inputs x = self.hidden(x) x = self.activation(x) g = self.re_slopes(z, training=training) x = self.dense_out((1 + g) * x) b = self.re_intercept(z, training=training) y = self.sigmoid(b + x) return y def get_config(self): return {'units': self.units} class MixedEffectsAEC(DomainAdversarialAEC): ''' Still under development!''' def __init__(self, image_shape: tuple=(256, 256, 1), n_clusters: int=10, n_latent_dims: int=56, encoder_layer_filters: list=[64, 64, 64, 128, 256, 512], classifier_hidden_units: int=32, adversary_layer_filters: list=[16, 32, 32, 64, 64, 128, 128, 256], post_loc_init_scale: float=0.1, prior_scale: float=0.25, kl_weight: float=1e-5, name='autoencoder', **kwargs ): """Mixed effects autoencoder-classifier. Adds an adversarial classifier to predict the cluster/domain membership of each sample based on the encoder's intermediate outputs. This compels the encoder to learn features unassociated with cluster characteristics. Then, random effects are learned in the decoder through cluster-specific feature scales and biases and in the auxiliary classifier through cluster -specific slopes and intercepts. Args: image_shape (tuple, optional): Image shape. Defaults to (256, 256, 1). n_clusters (int, optional): Number of clusters. Defaults to 10. n_latent_dims (int, optional): Size of latent representation. Defaults to 56. encoder_layer_filters (list, optional): Number of filters in each encoder layer. Defaults to [64, 64, 64, 128, 256, 512]. classifier_hidden_units (int, optional): Number of neurons in auxiliary classifier hidden layer. Defaults to 32. adversary_layer_filters (list, optional): Number of filters in each adversary layer. Defaults to [16, 32, 32, 64, 64, 128, 128, 256]. post_loc_init_scale (float, optional): S.d. for random normal initialization of posteriors. Defaults to 0.1. prior_scale (float, optional): S.d. of normal prior distributions. Defaults to 0.25. kl_weight (float, optional): KL Divergence loss weight. Defaults to 1e-5. name (str, optional): Model name. Defaults to 'autoencoder'. """ super(BaseAutoencoderClassifier, self).__init__( name=name, **kwargs) self.image_shape = image_shape self.n_latent_dims = n_latent_dims self.encoder_layer_filters = encoder_layer_filters self.decoder_layer_filters = encoder_layer_filters[-1::-1] self.classifier_hidden_units = classifier_hidden_units self.encoder = Encoder(n_latent_dims=n_latent_dims, layer_filters=encoder_layer_filters, return_layer_activations=True) lsEncoderLayers = [self.encoder.conv0, self.encoder.conv1, self.encoder.conv2, self.encoder.conv3, self.encoder.conv4, self.encoder.conv5] self.decoder_fe = TiedDecoder(lsEncoderLayers, image_shape=image_shape, layer_filters=self.decoder_layer_filters, name='decoder_fe') self.decoder_re = RandomEffectDecoder(layer_filters=self.decoder_layer_filters, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight, name='decoder_re') self.classifier = MixedEffectAuxClassifier(units=classifier_hidden_units, post_loc_init_scale=post_loc_init_scale, prior_scale=prior_scale, kl_weight=kl_weight) self.adversary = AdversarialClassifier(image_shape, n_clusters, layer_filters=adversary_layer_filters) self.recon_cluster_classifier = ImageClassifier(n_clusters=n_clusters, name='recon_classifier') def call(self, inputs, training=None): images, clusters = inputs encoder_outs = self.encoder(images, training=training) latent = encoder_outs[-1] recon_re = self.decoder_re((latent, clusters), training=training) recon_fe = self.decoder_fe(latent, training=training) classification = self.classifier((latent, clusters), training=training) pred_cluster = self.adversary(encoder_outs) return (recon_re, recon_fe, classification, pred_cluster) def compile(self, loss_recon=tf.keras.losses.MeanSquaredError(), loss_class=tf.keras.losses.BinaryCrossentropy(), loss_adv=tf.keras.losses.CategoricalCrossentropy(), metric_class=tf.keras.metrics.AUC(name='auroc'), metric_adv=tf.keras.metrics.CategoricalAccuracy(name='acc'), opt_autoencoder=tf.keras.optimizers.Adam(lr=0.0001), opt_adversary=tf.keras.optimizers.Adam(lr=0.0001), opt_recon_classifier=tf.keras.optimizers.Adam(lr=0.0001), loss_recon_weight=1, loss_recon_fe_weight=1, loss_class_weight=0.01, loss_gen_weight=0.2, loss_recon_cluster_weight=0.01): super().compile(loss_recon, loss_class, loss_adv, metric_class, metric_adv, opt_autoencoder, opt_adversary, loss_recon_weight, loss_class_weight, loss_gen_weight) self.loss_recon_fe_weight = loss_recon_fe_weight self.loss_recon_fe_tracker = tf.keras.metrics.Mean('recon_fe_loss') self.loss_kl_tracker = tf.keras.metrics.Mean('kld') self.loss_recon_cluster_weight = loss_recon_cluster_weight self.loss_recon_cluster_tracker = tf.keras.metrics.Mean('recon_clus_loss') self.opt_recon_classifier = opt_recon_classifier @property def metrics(self): return [self.loss_recon_tracker, self.loss_class_tracker, self.loss_adv_tracker, self.loss_kl_tracker, self.loss_recon_cluster_tracker, self.loss_total_tracker, self.metric_class, self.metric_adv] def _compute_update_loss(self, loss_recon_re, loss_recon_fe, loss_class, loss_gen, loss_recon_cluster, training=True): '''Compute total loss and update loss running means''' self.loss_recon_tracker.update_state(loss_recon_re) self.loss_recon_fe_tracker.update_state(loss_recon_fe) self.loss_class_tracker.update_state(loss_class) self.loss_recon_cluster_tracker.update_state(loss_recon_cluster) if training: kld = tf.reduce_mean(self.decoder_re.losses) + tf.reduce_mean(self.classifier.losses) self.loss_kl_tracker.update_state(kld) else: # KLD can't be computed at inference time because posteriors are simplified to # point estimates kld = 0 loss_total = (self.loss_recon_weight * loss_recon_re) \ + (self.loss_recon_fe_weight * loss_recon_fe) \ + (self.loss_class_weight * loss_class) \ + (self.loss_gen_weight * loss_gen) \ + (self.loss_recon_cluster_weight * loss_recon_cluster) \ + kld self.loss_total_tracker.update_state(loss_total) return loss_total def train_step(self, data): if len(data) == 3: (images, clusters), (_, labels), sample_weights = data else: (images, clusters), (_, labels) = data sample_weights = None # Train adversarial classifier encoder_outs = self.encoder(images, training=True) with tf.GradientTape() as gt: pred_cluster = self.adversary(encoder_outs) loss_adv = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights) grads_adv = gt.gradient(loss_adv, self.adversary.trainable_variables) self.opt_adversary.apply_gradients(zip(grads_adv, self.adversary.trainable_variables)) self.metric_adv.update_state(clusters, pred_cluster) self.loss_adv_tracker.update_state(loss_adv) # Train the recon-cluster classifier on real images with tf.GradientTape() as gt2: # Predict cluster from real images pred_c_image = self.recon_cluster_classifier(images) loss_image_cluster = self.loss_adv(clusters, pred_c_image, sample_weight=sample_weights) grads_rc = gt2.gradient(loss_image_cluster, self.recon_cluster_classifier.trainable_variables) self.opt_recon_classifier.apply_gradients(zip(grads_rc, self.recon_cluster_classifier.trainable_variables)) self.loss_recon_cluster_tracker.update_state(loss_image_cluster) # Train the rest of the model with tf.GradientTape(persistent=True) as gt3: pred_recon_re, pred_recon_fe, pred_class, pred_cluster = self((images, clusters), training=True) loss_class = self.loss_class(labels, pred_class, sample_weight=sample_weights) loss_recon_re = self.loss_recon(images, pred_recon_re, sample_weight=sample_weights) loss_recon_fe = self.loss_recon(images, pred_recon_fe, sample_weight=sample_weights) loss_gen = self.loss_adv(clusters, pred_cluster, sample_weight=sample_weights) pred_recon_cluster = self.recon_cluster_classifier(pred_recon_re) loss_recon_cluster = self.loss_adv(clusters, pred_recon_cluster) total_loss = self._compute_update_loss(loss_recon_re, loss_recon_fe, loss_class, loss_gen, loss_recon_cluster, training=True) lsWeights = self.encoder.trainable_variables + self.decoder_re.trainable_variables \ + self.decoder_fe.trainable_variables + self.classifier.trainable_variables grads_aec = gt3.gradient(total_loss, lsWeights) self.opt_autoencoder.apply_gradients(zip(grads_aec, lsWeights)) self.metric_class.update_state(labels, pred_class) return {m.name: m.result() for m in self.metrics} def test_step(self, data): (images, clusters), (_, labels) = data pred_recon_re, pred_recon_fe, pred_class, pred_cluster = self((images, clusters), training=False) loss_class = self.loss_class(labels, pred_class) loss_recon_re = self.loss_recon(images, pred_recon_re) loss_recon_fe = self.loss_recon(images, pred_recon_fe) loss_adv = self.loss_adv(clusters, pred_cluster) pred_recon_cluster = self.recon_cluster_classifier(pred_recon_re) loss_recon_cluster = self.loss_adv(clusters, pred_recon_cluster) _ = self._compute_update_loss(loss_recon_re, loss_recon_fe, loss_class, loss_adv, loss_recon_cluster, training=False) self.metric_class.update_state(labels, pred_class) self.metric_adv.update_state(clusters, pred_cluster) self.loss_adv_tracker.update_state(loss_adv) return {m.name: m.result() for m in self.metrics}