'''
Core Bayesian random effects layers. Code adapted from Nguyen et al. 2023: tinyurl.com/ARMEDCode
'''
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tkl
from tensorflow_probability import layers as tpl
from tensorflow_probability import distributions as tpd
from tensorflow_addons.layers import InstanceNormalization  # only used by the commented-out instance-norm option below
def make_posterior_fn(post_loc_init_scale, post_scale_init_min, post_scale_init_range):
    """Create a factory for the mean-field posterior over the random effect weights.

    Args:
        post_loc_init_scale (float): s.d. of the random normal initializer for the
            posterior means.
        post_scale_init_min (float): lower bound of the random uniform initializer
            for the (softplus-transformed) posterior scales.
        post_scale_init_range (float): width of that random uniform initializer.

    Returns:
        callable: function with signature (kernel_size, bias_size=0, dtype=None)
            returning a tf.keras.Sequential that outputs the posterior distribution.
    """
    def _re_posterior_fn(kernel_size, bias_size=0, dtype=None):
        n = kernel_size + bias_size
        # 2n variables in total: n posterior means followed by n (softplus-transformed)
        # posterior scales, one of each per weight
        initializer = tpl.BlockwiseInitializer(
            [tf.keras.initializers.RandomNormal(mean=0, stddev=post_loc_init_scale),
             tf.keras.initializers.RandomUniform(minval=post_scale_init_min,
                                                 maxval=post_scale_init_min + post_scale_init_range)],
            sizes=[n, n])
        return tf.keras.Sequential([
            tpl.VariableLayer(n + n, dtype=dtype, initializer=initializer),
            tpl.DistributionLambda(lambda t: tpd.Independent(
                tpd.Normal(loc=t[..., :n], scale=1e-5 + tf.nn.softplus(t[..., n:])),
                reinterpreted_batch_ndims=1))
        ])
    return _re_posterior_fn
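
# Illustrative usage sketch (not part of the original module; sizes and values below
# are arbitrary). The posterior factory can be inspected on its own:
#
#   posterior_fn = make_posterior_fn(0.05, 0.05, 0.05)
#   q_model = posterior_fn(kernel_size=4)   # Sequential: VariableLayer -> DistributionLambda
#   q = q_model(tf.zeros([1]))              # the input is ignored by VariableLayer
#   q.event_shape                           # (4,): one Normal per weight
#   q.sample()                              # draw one set of the 4 weights
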
def make_fixed_prior_fn(prior_scale):
    """Create a factory for a fixed zero-mean normal prior with s.d. prior_scale."""
    def _prior_fn(kernel_size, bias_size=0, dtype=None):
        n = kernel_size + bias_size
        return tf.keras.Sequential([
            tpl.DistributionLambda(lambda t: tpd.Independent(
                tpd.Normal(loc=tf.zeros(n), scale=prior_scale),
                reinterpreted_batch_ndims=1))
        ])
    return _prior_fn
def make_trainable_prior_fn(prior_scale):
    """Create a factory for a zero-mean normal prior with a learnable scale.

    Note that the scale variable is passed through a softplus, so the effective
    initial scale is softplus(prior_scale) rather than prior_scale itself.
    """
    def _prior_fn(kernel_size, bias_size=0, dtype=None):
        n = kernel_size + bias_size
        initializer = tf.initializers.Constant(prior_scale)
        return tf.keras.Sequential([
            tpl.VariableLayer(n, dtype=dtype, initializer=initializer),
            tpl.DistributionLambda(lambda t:
                tpd.Normal(loc=tf.zeros(n), scale=1e-5 + tf.nn.softplus(t)))
        ])
    return _prior_fn
class RandomEffects(tpl.DenseVariational):
def __init__(self,
units: int=1,
post_loc_init_scale: float=0.05,
post_scale_init_min: float=0.05,
post_scale_init_range: float=0.05,
prior_scale: float=0.05,
kl_weight: float=0.001,
l1_weight: float=None,
name=None) -> None:
"""Core random effects layer, which learns cluster-specific parameters
regularized to a zero-mean normal distribution. It takes as input a
one-hot encoded matrix Z indicating the cluster membership of each sample,
then returns a vector of cluster-specific parameters u(Z). Each parameter
is regularized to follow zero-mean normal distribution.
Args:
units (int, optional): Number of parameters. Defaults to 1.
post_loc_init_scale (float, optional): S.d. for initializing
posterior means with a random normal distribution. Defaults to 0.05.
            post_scale_init_min (float, optional): Lower bound of the random uniform
                distribution used to initialize the posterior scales (s.d.).
                Defaults to 0.05.
            post_scale_init_range (float, optional): Width of the random uniform
                distribution used to initialize the posterior scales (s.d.).
                Defaults to 0.05.
prior_scale (float, optional): S.d. of prior distribution. Defaults to 0.05.
kl_weight (float, optional): KL divergence weight. Defaults to 0.001.
l1_weight (float, optional): L1 regularization weight. Defaults to None.
name (str, optional): Layer name. Defaults to None.
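
        Example (illustrative sketch, not from the original code; the one-hot
        matrix Z below is constructed ad hoc):

            # 3 clusters, batch of 4 samples
            Z = tf.one_hot([0, 1, 2, 0], depth=3)
            re_layer = RandomEffects(units=2, kl_weight=0.001)
            u = re_layer(Z, training=True)  # shape (4, 2): cluster-specific parameters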
"""
self.kl_weight = kl_weight
self.l1_weight = l1_weight
        # The posterior scale is stored as a softplus-transformed weight, so the
        # given initialization args must be converted using the inverse softplus
fPostScaleMin = np.log(np.exp(post_scale_init_min) - 1)
fPostScaleRange = np.log(np.exp(post_scale_init_range) - 1)
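        # (inverse softplus: stored = log(exp(scale) - 1), so that softplus(stored) recovers scale)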
posterior = make_posterior_fn(post_loc_init_scale, fPostScaleMin, fPostScaleRange)
prior = make_fixed_prior_fn(prior_scale)
super().__init__(units, posterior, prior, use_bias=False,
kl_weight=kl_weight,
name=name)
def call(self, inputs, training=None):
        if training is False:
            # At inference time, use the posterior means instead of sampling
            if not self._posterior.built:
                self._posterior.build(inputs.shape)
            if not self._prior.built:
                self._prior.build(inputs.shape)
# First half of weights contains the posterior means
nWeights = self.weights[0].shape[0]
w = self.weights[0][:(nWeights // 2)]
prev_units = self.input_spec.axes[-1]
kernel = tf.reshape(w, shape=tf.concat([
tf.shape(w)[:-1],
[prev_units, self.units],
], axis=0))
#print("\nrandom effects inputs", inputs, "dtype:", inputs.dtype)
#print("\nrandom effects kernel", kernel, "dtype:", kernel.dtype)
inputs = tf.cast(inputs, tf.float32)
outputs = tf.matmul(inputs, kernel)
if self.activation is not None:
outputs = self.activation(outputs) # pylint: disable=not-callable
else:
outputs = super().call(inputs)
if self.l1_weight:
# First half of weights contains the posterior means
nWeights = self.weights[0].shape[0]
postmeans = self.weights[0][:(nWeights // 2)]
self.add_loss(self.l1_weight * tf.reduce_sum(tf.abs(postmeans)))
return outputs
class ClusterScaleBiasBlock(tf.keras.Model):
def __init__(self,
n_features,
post_loc_init_scale=0.25,
prior_scale=0.25,
kl_weight=0.001,
name='cluster',
**kwargs):
"""Layer applying cluster-specific random scales and biases to the
output of a convolution layer.
This layer learns cluster-specific scale vectors 'gamma(Z)' and bias
vectors 'beta(Z)', where Z is the one-hot. These vectors have length
equal to the number of filters in the preceding convolution layer.
After instance-normalzing the input x, the following operation is
applied:
(1 + gamma) * x + beta
Any activation function should be placed after this layer. Other
normalization layers should not be used.
Args:
n_features (int): Number of filters in preceding convolution layer.
post_loc_init_scale (float, optional): S.d. for initializing
posterior means with a random normal distribution. Defaults to 0.25.
prior_scale (float, optional): S.d. of normal prior distribution. Defaults to 0.25.
kl_weight (float, optional): KL divergence weight. Defaults to 0.001.
name (str, optional): Layer name. Defaults to 'cluster'.
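
        Example (illustrative sketch, not from the original code; shapes are
        arbitrary):

            x = tf.random.normal([2, 8, 8, 16])  # feature maps from a conv layer with 16 filters
            Z = tf.one_hot([0, 1], depth=4)      # cluster membership for the 2 samples
            block = ClusterScaleBiasBlock(n_features=16)
            y = block((x, Z), training=True)     # shape (2, 8, 8, 16)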
"""
super(ClusterScaleBiasBlock, self).__init__(name=name, **kwargs)
self.n_features = n_features
self.post_loc_init_scale = post_loc_init_scale
self.prior_scale = prior_scale
self.kl_weight = kl_weight
# self.instance_norm = InstanceNormalization(center=True,
# scale=True,
# name=name + '_instance_norm')
        self.bn = tf.keras.layers.BatchNormalization(name=name + '_batch_norm')
self.gammas = RandomEffects(n_features,
post_loc_init_scale=post_loc_init_scale,
post_scale_init_min=0.01,
post_scale_init_range=0.005,
prior_scale=prior_scale,
kl_weight=kl_weight,
name=name + '_gammas')
self.betas = RandomEffects(n_features,
post_loc_init_scale=post_loc_init_scale,
post_scale_init_min=0.01,
post_scale_init_range=0.005,
prior_scale=prior_scale,
kl_weight=kl_weight,
name=name + '_betas')
def call(self, inputs, training=None):
x, z = inputs
        # x = self.instance_norm(x)
        # Batch normalization is used instead; there was a bug when using instance_norm
        x = self.bn(x, training=training)
        g = self.gammas(z, training=training)
        b = self.betas(z, training=training)
        # Reshape to batch_size x 1 x ... x 1 x n_features so the scale and bias
        # broadcast over any spatial dimensions of x
        if len(x.shape) > 2:
            new_dims = len(x.shape) - 2
            g = tf.reshape(g, [-1] + [1] * new_dims + [self.n_features])
            b = tf.reshape(b, [-1] + [1] * new_dims + [self.n_features])
m = x * (1 + g)
s = m + b
return s
    def get_config(self):
        return {'n_features': self.n_features,
                'post_loc_init_scale': self.post_loc_init_scale,
                'prior_scale': self.prior_scale,
                'kl_weight': self.kl_weight,
                'name': self.name}