diff --git a/mlpp_lib/probabilistic_layers.py b/mlpp_lib/probabilistic_layers.py
index 01bca9e..615e4c2 100644
--- a/mlpp_lib/probabilistic_layers.py
+++ b/mlpp_lib/probabilistic_layers.py
@@ -1,4 +1,5 @@
 """In this module, any custom built keras layers are included."""
+
 import numpy as np
 import tensorflow as tf
 import tensorflow_probability.python.layers as tfpl
@@ -21,6 +22,7 @@
     IndependentPoisson,
 )
 
+
 @tf.keras.saving.register_keras_serializable()
 class IndependentNormal(IndependentNormal):
     @property
@@ -51,7 +53,7 @@ def output(self):
 
 @tf.keras.saving.register_keras_serializable()
 class IndependentBeta(tfpl.DistributionLambda):
-    """An independent 4-parameter Beta Keras layer."""
+    """An independent 2-parameter Beta Keras layer."""
 
     def __init__(
         self,
@@ -96,6 +98,115 @@ def __init__(
     def new(params, event_shape=(), validate_args=False, name=None):
         """Create the distribution instance from a `params` vector."""
         with tf.name_scope(name or "IndependentBeta"):
+            params = tf.convert_to_tensor(params, name="params")
+            event_shape = dist_util.expand_to_vector(
+                tf.convert_to_tensor(
+                    event_shape, name="event_shape", dtype_hint=tf.int32
+                ),
+                tensor_name="event_shape",
+            )
+            output_shape = tf.concat(
+                [
+                    tf.shape(params)[:-1],
+                    event_shape,
+                ],
+                axis=0,
+            )
+            alpha, beta = tf.split(params, 2, axis=-1)
+
+            alpha = tf.math.softplus(tf.reshape(alpha, output_shape)) + 1e-3
+            beta = tf.math.softplus(tf.reshape(beta, output_shape)) + 1e-3
+            betad = tfd.Beta(alpha, beta, validate_args=validate_args)
+
+            return independent_lib.Independent(
+                betad,
+                reinterpreted_batch_ndims=tf.size(event_shape),
+                validate_args=validate_args,
+            )
+
+    @staticmethod
+    def params_size(event_shape=(), name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "IndependentBeta_params_size"):
+            event_shape = tf.convert_to_tensor(
+                event_shape, name="event_shape", dtype_hint=tf.int32
+            )
+            return np.int32(2) * _event_size(
+                event_shape, name=name or "IndependentBeta_params_size"
+            )
+
+    def get_config(self):
+        """Returns the config of this layer.
+        NOTE: At the moment, this configuration can only be serialized if the
+        Layer's `convert_to_tensor_fn` is a serializable Keras object (i.e.,
+        implements `get_config`) or one of the standard values:
+        - `Distribution.sample` (or `"sample"`)
+        - `Distribution.mean` (or `"mean"`)
+        - `Distribution.mode` (or `"mode"`)
+        - `Distribution.stddev` (or `"stddev"`)
+        - `Distribution.variance` (or `"variance"`)
+        """
+        config = {
+            "event_shape": self._event_shape,
+            "convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn),
+            "validate_args": self._validate_args,
+        }
+        base_config = super(IndependentBeta, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    @property
+    def output(self):
+        """This allows the use of this layer with the shap package."""
+        return super(IndependentBeta, self).output[0]
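+
+# Usage sketch (editor's illustration; `hidden` and the Dense head are assumed
+# names, not part of this module): the layer consumes a dense output sized by
+# `params_size`, e.g.
+#
+#     n = IndependentBeta.params_size(1)           # -> 2 (alpha, beta)
+#     y = IndependentBeta(1)(tf.keras.layers.Dense(n)(hidden))
+#
+# softplus(...) + 1e-3 keeps both concentration parameters strictly positive.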
+
+
+@tf.keras.saving.register_keras_serializable()
+class Independent4ParamsBeta(tfpl.DistributionLambda):
+    """An independent 4-parameter Beta Keras layer allowing control over scale as well as a 'shift' parameter."""
+
+    def __init__(
+        self,
+        event_shape=(),
+        convert_to_tensor_fn=tfd.Distribution.mean,
+        validate_args=False,
+        **kwargs
+    ):
+        """Initialize the `Independent4ParamsBeta` layer.
+        Args:
+            event_shape: integer vector `Tensor` representing the shape of single
+                draw from this distribution.
+            convert_to_tensor_fn: Python `callable` that takes a `tfd.Distribution`
+                instance and returns a `tf.Tensor`-like object.
+                Default value: `tfd.Distribution.mean`.
+            validate_args: Python `bool`, default `False`. When `True` distribution
+                parameters are checked for validity despite possibly degrading runtime
+                performance. When `False` invalid inputs may silently render incorrect
+                outputs.
+                Default value: `False`.
+            **kwargs: Additional keyword arguments passed to `tf.keras.Layer`.
+        """
+        convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn)
+
+        # If there is a 'make_distribution_fn' keyword argument (e.g., because we
+        # are being called from a `from_config` method), remove it. We pass the
+        # distribution function to `DistributionLambda.__init__` below as the first
+        # positional argument.
+        kwargs.pop("make_distribution_fn", None)
+
+        super(Independent4ParamsBeta, self).__init__(
+            lambda t: Independent4ParamsBeta.new(t, event_shape, validate_args),
+            convert_to_tensor_fn,
+            **kwargs
+        )
+
+        self._event_shape = event_shape
+        self._convert_to_tensor_fn = convert_to_tensor_fn
+        self._validate_args = validate_args
+
+    @staticmethod
+    def new(params, event_shape=(), validate_args=False, name=None):
+        """Create the distribution instance from a `params` vector."""
+        with tf.name_scope(name or "Independent4ParamsBeta"):
             params = tf.convert_to_tensor(params, name="params")
             event_shape = dist_util.expand_to_vector(
                 tf.convert_to_tensor(
@@ -112,10 +223,10 @@ def new(params, event_shape=(), validate_args=False, name=None):
             )
             alpha, beta, shift, scale = tf.split(params, 4, axis=-1)
-            # alpha > 2 and beta > 2 produce a concave downward Beta
-            alpha = 2.0 + tf.math.softplus(tf.reshape(alpha, output_shape))
-            beta = 2.0 + tf.math.softplus(tf.reshape(beta, output_shape))
+            # softplus with a small offset keeps all parameters strictly positive
+            alpha = tf.math.softplus(tf.reshape(alpha, output_shape)) + 1e-3
+            beta = tf.math.softplus(tf.reshape(beta, output_shape)) + 1e-3
             shift = tf.math.softplus(tf.reshape(shift, output_shape))
-            scale = tf.math.softplus(tf.reshape(scale, output_shape))
+            scale = tf.math.softplus(tf.reshape(scale, output_shape)) + 1e-3
             betad = tfd.Beta(alpha, beta, validate_args=validate_args)
             transf_betad = tfd.TransformedDistribution(
                 distribution=betad, bijector=tfb.Shift(shift)(tfb.Scale(scale))
@@ -129,12 +240,12 @@ def new(params, event_shape=(), validate_args=False, name=None):
 
     @staticmethod
     def params_size(event_shape=(), name=None):
         """The number of `params` needed to create a single distribution."""
-        with tf.name_scope(name or "IndependentBeta_params_size"):
+        with tf.name_scope(name or "Independent4ParamsBeta_params_size"):
             event_shape = tf.convert_to_tensor(
                 event_shape, name="event_shape", dtype_hint=tf.int32
             )
             return np.int32(4) * _event_size(
-                event_shape, name=name or "IndependentBeta_params_size"
+                event_shape, name=name or "Independent4ParamsBeta_params_size"
             )
 
     def get_config(self):
@@ -153,13 +264,279 @@ def get_config(self):
             "convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn),
             "validate_args": self._validate_args,
         }
-        base_config = super(IndependentBeta, self).get_config()
+        base_config = super(Independent4ParamsBeta, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
     @property
     def output(self):
         """This allows the use of this layer with the shap package."""
-        return super(IndependentBeta, self).output[0]
+        return super(Independent4ParamsBeta, self).output[0]
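+
+# Editor's note (sketch): composing `tfb.Shift(shift)(tfb.Scale(scale))` maps
+# the base Beta's support [0, 1] onto [shift, shift + scale]; e.g. shift = 0.5
+# and scale = 2.0 give a distribution on [0.5, 2.5]. params_size(1) -> 4.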
+ """An independent censored normal Keras layer.""" + + def __init__( + self, + event_shape=(), + convert_to_tensor_fn=tfd.Distribution.mean, + validate_args=False, + **kwargs + ): + """Initialize the `IndependentDoublyCensoredNormal` layer. + Args: + event_shape: integer vector `Tensor` representing the shape of single + draw from this distribution. + convert_to_tensor_fn: Python `callable` that takes a `tfd.Distribution` + instance and returns a `tf.Tensor`-like object. + Default value: `tfd.Distribution.mean`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + Default value: `False`. + **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. + """ + convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn) + + # If there is a 'make_distribution_fn' keyword argument (e.g., because we + # are being called from a `from_config` method), remove it. We pass the + # distribution function to `DistributionLambda.__init__` below as the first + # positional argument. + kwargs.pop("make_distribution_fn", None) + + super(IndependentDoublyCensoredNormal, self).__init__( + lambda t: IndependentDoublyCensoredNormal.new( + t, event_shape, validate_args + ), + convert_to_tensor_fn, + **kwargs + ) + + self._event_shape = event_shape + self._convert_to_tensor_fn = convert_to_tensor_fn + self._validate_args = validate_args + + @staticmethod + def new(params, event_shape=(), validate_args=False, name=None): + """Create the distribution instance from a `params` vector.""" + with tf.name_scope(name or "IndependentDoublyCensoredNormal"): + params = tf.convert_to_tensor(params, name="params") + event_shape = dist_util.expand_to_vector( + tf.convert_to_tensor( + event_shape, name="event_shape", dtype_hint=tf.int32 + ), + tensor_name="event_shape", + ) + output_shape = tf.concat( + [ + tf.shape(params)[:-1], + event_shape, + ], + axis=0, + ) + loc, scale = tf.split(params, 2, axis=-1) + loc = tf.reshape(loc, output_shape) + scale = tf.math.softplus(tf.reshape(scale, output_shape)) + 1e-6 + normal_dist = tfd.Normal(loc=loc, scale=scale, validate_args=validate_args) + + class CustomCensored(tfd.Distribution): + def __init__(self, normal): + self.normal = normal + super(CustomCensored, self).__init__( + dtype=normal.dtype, + reparameterization_type=tfd.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=True, + ) + + def _sample_n(self, n, seed=None): + + # Sample from normal distribution + samples = self.normal.sample(sample_shape=(n,), seed=seed) + + # Clip values between 0 and 1 + chosen_samples = tf.clip_by_value(samples, 0, 1) + + return chosen_samples + + def _mean(self): + original_mean = self.normal.mean() + low_bound_standard = (0 - original_mean) / self.normal.stddev() + high_bound_standard = (1 - original_mean) / self.normal.stddev() + + self.low_bound_cdf = self.normal.cdf(low_bound_standard) + self.high_bound_cdf = self.normal.cdf(high_bound_standard) + + self.low_bound_pdf = self.normal.prob(low_bound_standard) + self.high_bound_pdf = self.normal.prob(high_bound_standard) + + return original_mean + self.normal.stddev() * ( + self.low_bound_pdf - self.high_bound_pdf + ) / (self.high_bound_cdf - self.low_bound_cdf + 1e-3) + + def _log_prob(self, value): + original_log_prob = self.normal.log_prob(value) + + return original_log_prob - tf.math.log( + self.high_bound_cdf - 
+
+            return independent_lib.Independent(
+                CustomCensored(normal_dist),
+                reinterpreted_batch_ndims=tf.size(event_shape),
+                validate_args=validate_args,
+            )
+
+    @staticmethod
+    def params_size(event_shape=(), name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "IndependentDoublyCensoredNormal_params_size"):
+            event_shape = tf.convert_to_tensor(
+                event_shape, name="event_shape", dtype_hint=tf.int32
+            )
+            return np.int32(2) * _event_size(
+                event_shape, name=name or "IndependentDoublyCensoredNormal_params_size"
+            )
+
+    def get_config(self):
+        """Returns the config of this layer.
+        NOTE: At the moment, this configuration can only be serialized if the
+        Layer's `convert_to_tensor_fn` is a serializable Keras object (i.e.,
+        implements `get_config`) or one of the standard values:
+        - `Distribution.sample` (or `"sample"`)
+        - `Distribution.mean` (or `"mean"`)
+        - `Distribution.mode` (or `"mode"`)
+        - `Distribution.stddev` (or `"stddev"`)
+        - `Distribution.variance` (or `"variance"`)
+        """
+        config = {
+            "event_shape": self._event_shape,
+            "convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn),
+            "validate_args": self._validate_args,
+        }
+        base_config = super(IndependentDoublyCensoredNormal, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    @property
+    def output(self):
+        """This allows the use of this layer with the shap package."""
+        return super(IndependentDoublyCensoredNormal, self).output[0]
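+
+# Usage sketch (editor's illustration; the Dense head is an assumed name): the
+# censored layer is parameterized like a plain normal (params_size(()) -> 2);
+# only sampling clips to the unit interval, e.g.
+#
+#     y = IndependentDoublyCensoredNormal()(tf.keras.layers.Dense(2)(hidden))
+#     y.sample()  # draws censored to [0, 1]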
+ kwargs.pop("make_distribution_fn", None) + + super(IndependentConcaveBeta, self).__init__( + lambda t: IndependentConcaveBeta.new(t, event_shape, validate_args), + convert_to_tensor_fn, + **kwargs + ) + + self._event_shape = event_shape + self._convert_to_tensor_fn = convert_to_tensor_fn + self._validate_args = validate_args + + @staticmethod + def new(params, event_shape=(), validate_args=False, name=None): + """Create the distribution instance from a `params` vector.""" + with tf.name_scope(name or "IndependentConcaveBeta"): + params = tf.convert_to_tensor(params, name="params") + event_shape = dist_util.expand_to_vector( + tf.convert_to_tensor( + event_shape, name="event_shape", dtype_hint=tf.int32 + ), + tensor_name="event_shape", + ) + output_shape = tf.concat( + [ + tf.shape(params)[:-1], + event_shape, + ], + axis=0, + ) + alpha, beta, shift, scale = tf.split(params, 4, axis=-1) + # alpha > 2 and beta > 2 produce a concave downward Beta + alpha = tf.math.softplus(tf.reshape(alpha, output_shape)) + 2.0 + beta = tf.math.softplus(tf.reshape(beta, output_shape)) + 2.0 + shift = tf.math.softplus(tf.reshape(shift, output_shape)) + scale = tf.math.softplus(tf.reshape(scale, output_shape)) + 1e-3 + betad = tfd.Beta(alpha, beta, validate_args=validate_args) + transf_betad = tfd.TransformedDistribution( + distribution=betad, bijector=tfb.Shift(shift)(tfb.Scale(scale)) + ) + return independent_lib.Independent( + transf_betad, + reinterpreted_batch_ndims=tf.size(event_shape), + validate_args=validate_args, + ) + + @staticmethod + def params_size(event_shape=(), name=None): + """The number of `params` needed to create a single distribution.""" + with tf.name_scope(name or "IndependentConcaveBeta_params_size"): + event_shape = tf.convert_to_tensor( + event_shape, name="event_shape", dtype_hint=tf.int32 + ) + return np.int32(4) * _event_size( + event_shape, name=name or "IndependentConcaveBeta_params_size" + ) + + def get_config(self): + """Returns the config of this layer. + NOTE: At the moment, this configuration can only be serialized if the + Layer's `convert_to_tensor_fn` is a serializable Keras object (i.e., + implements `get_config`) or one of the standard values: + - `Distribution.sample` (or `"sample"`) + - `Distribution.mean` (or `"mean"`) + - `Distribution.mode` (or `"mode"`) + - `Distribution.stddev` (or `"stddev"`) + - `Distribution.variance` (or `"variance"`) + """ + config = { + "event_shape": self._event_shape, + "convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn), + "validate_args": self._validate_args, + } + base_config = super(IndependentConcaveBeta, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @property + def output(self): + """This allows the use of this layer with the shap package.""" + return super(IndependentConcaveBeta, self).output[0] @tf.keras.saving.register_keras_serializable() @@ -380,6 +757,290 @@ def output(self): return super(IndependentLogNormal, self).output[0] +@tf.keras.saving.register_keras_serializable() +class IndependentLogitNormal(tfpl.DistributionLambda): + """An independent Logit-Normal Keras layer.""" + + def __init__( + self, + event_shape=(), + convert_to_tensor_fn=tfd.Distribution.sample, + validate_args=False, + **kwargs + ): + """Initialize the `IndependentLogitNormal` layer. + Args: + event_shape: integer vector `Tensor` representing the shape of single + draw from this distribution. 
 
 
 @tf.keras.saving.register_keras_serializable()
@@ -380,6 +757,290 @@ def output(self):
         return super(IndependentLogNormal, self).output[0]
 
 
+@tf.keras.saving.register_keras_serializable()
+class IndependentLogitNormal(tfpl.DistributionLambda):
+    """An independent Logit-Normal Keras layer."""
+
+    def __init__(
+        self,
+        event_shape=(),
+        convert_to_tensor_fn=tfd.Distribution.sample,
+        validate_args=False,
+        **kwargs
+    ):
+        """Initialize the `IndependentLogitNormal` layer.
+        Args:
+            event_shape: integer vector `Tensor` representing the shape of single
+                draw from this distribution.
+            convert_to_tensor_fn: Python `callable` that takes a `tfd.Distribution`
+                instance and returns a `tf.Tensor`-like object.
+                Default value: `tfd.Distribution.sample`.
+            validate_args: Python `bool`, default `False`. When `True` distribution
+                parameters are checked for validity despite possibly degrading runtime
+                performance. When `False` invalid inputs may silently render incorrect
+                outputs.
+                Default value: `False`.
+            **kwargs: Additional keyword arguments passed to `tf.keras.Layer`.
+        """
+        convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn)
+
+        # If there is a 'make_distribution_fn' keyword argument (e.g., because we
+        # are being called from a `from_config` method), remove it. We pass the
+        # distribution function to `DistributionLambda.__init__` below as the first
+        # positional argument.
+        kwargs.pop("make_distribution_fn", None)
+
+        super(IndependentLogitNormal, self).__init__(
+            lambda t: IndependentLogitNormal.new(t, event_shape, validate_args),
+            convert_to_tensor_fn,
+            **kwargs
+        )
+
+        self._event_shape = event_shape
+        self._convert_to_tensor_fn = convert_to_tensor_fn
+        self._validate_args = validate_args
+
+    @staticmethod
+    def new(params, event_shape=(), validate_args=False, name=None):
+        """Create the distribution instance from a `params` vector."""
+        with tf.name_scope(name or "IndependentLogitNormal"):
+            params = tf.convert_to_tensor(params, name="params")
+            event_shape = dist_util.expand_to_vector(
+                tf.convert_to_tensor(
+                    event_shape, name="event_shape", dtype_hint=tf.int32
+                ),
+                tensor_name="event_shape",
+            )
+            output_shape = tf.concat(
+                [
+                    tf.shape(params)[:-1],
+                    event_shape,
+                ],
+                axis=0,
+            )
+            loc, scale = tf.split(params, 2, axis=-1)
+            return independent_lib.Independent(
+                tfd.LogitNormal(
+                    loc=tf.reshape(loc, output_shape),
+                    scale=tf.math.softplus(tf.reshape(scale, output_shape)) + 1e-3,
+                    validate_args=validate_args,
+                ),
+                reinterpreted_batch_ndims=tf.size(event_shape),
+                validate_args=validate_args,
+            )
+
+    @staticmethod
+    def params_size(event_shape=(), name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "IndependentLogitNormal_params_size"):
+            event_shape = tf.convert_to_tensor(
+                event_shape, name="event_shape", dtype_hint=tf.int32
+            )
+            return np.int32(2) * _event_size(
+                event_shape, name=name or "IndependentLogitNormal_params_size"
+            )
+
+    def get_config(self):
+        """Returns the config of this layer.
+        NOTE: At the moment, this configuration can only be serialized if the
+        Layer's `convert_to_tensor_fn` is a serializable Keras object (i.e.,
+        implements `get_config`) or one of the standard values:
+        - `Distribution.sample` (or `"sample"`)
+        - `Distribution.mean` (or `"mean"`)
+        - `Distribution.mode` (or `"mode"`)
+        - `Distribution.stddev` (or `"stddev"`)
+        - `Distribution.variance` (or `"variance"`)
+        """
+        config = {
+            "event_shape": self._event_shape,
+            "convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn),
+            "validate_args": self._validate_args,
+        }
+        base_config = super(IndependentLogitNormal, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    @property
+    def output(self):
+        """This allows the use of this layer with the shap package."""
+        return super(IndependentLogitNormal, self).output[0]
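+
+# Editor's note: unlike the other layers in this module, `convert_to_tensor_fn`
+# defaults to `tfd.Distribution.sample` here, presumably because the
+# logit-normal mean has no closed form and `mean()` cannot serve as the
+# conversion function.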
+
+
+@tf.keras.saving.register_keras_serializable()
+class IndependentMixtureNormal(tfpl.DistributionLambda):
+    """A mixture of two normal distributions Keras layer.
+    5-parameter distribution: loc1, scale1, loc2, scale2, weight.
+    """
+
+    def __init__(
+        self,
+        event_shape=(),
+        convert_to_tensor_fn=tfd.Distribution.mean,
+        validate_args=False,
+        **kwargs
+    ):
+        """Initialize the `IndependentMixtureNormal` layer.
+        Args:
+            event_shape: integer vector `Tensor` representing the shape of single
+                draw from this distribution.
+            convert_to_tensor_fn: Python `callable` that takes a `tfd.Distribution`
+                instance and returns a `tf.Tensor`-like object.
+                Default value: `tfd.Distribution.mean`.
+            validate_args: Python `bool`, default `False`. When `True` distribution
+                parameters are checked for validity despite possibly degrading runtime
+                performance. When `False` invalid inputs may silently render incorrect
+                outputs.
+                Default value: `False`.
+            **kwargs: Additional keyword arguments passed to `tf.keras.Layer`.
+        """
+
+        convert_to_tensor_fn = _get_convert_to_tensor_fn(convert_to_tensor_fn)
+
+        # If there is a 'make_distribution_fn' keyword argument (e.g., because we
+        # are being called from a `from_config` method), remove it. We pass the
+        # distribution function to `DistributionLambda.__init__` below as the first
+        # positional argument.
+        kwargs.pop("make_distribution_fn", None)
+
+        super(IndependentMixtureNormal, self).__init__(
+            lambda t: IndependentMixtureNormal.new(t, event_shape, validate_args),
+            convert_to_tensor_fn,
+            **kwargs
+        )
+
+        self._event_shape = event_shape
+        self._convert_to_tensor_fn = convert_to_tensor_fn
+        self._validate_args = validate_args
+
+    @staticmethod
+    def new(params, event_shape=(), validate_args=False, name=None):
+        """Create the distribution instance from a `params` vector."""
+        with tf.name_scope(name or "IndependentMixtureNormal"):
+            params = tf.convert_to_tensor(params, name="params")
+
+            event_shape = dist_util.expand_to_vector(
+                tf.convert_to_tensor(
+                    event_shape, name="event_shape", dtype_hint=tf.int32
+                ),
+                tensor_name="event_shape",
+            )
+
+            output_shape = tf.concat(
+                [
+                    tf.shape(params)[:-1],
+                    event_shape,
+                ],
+                axis=0,
+            )
+
+            loc1, scale1, loc2, scale2, weight = tf.split(params, 5, axis=-1)
+            loc1 = tf.reshape(loc1, output_shape)
+            scale1 = tf.math.softplus(tf.reshape(scale1, output_shape)) + 1e-3
+            loc2 = tf.reshape(loc2, output_shape)
+            scale2 = tf.math.softplus(tf.reshape(scale2, output_shape)) + 1e-3
+            weight = tf.math.sigmoid(tf.reshape(weight, output_shape))
+
+            # Create the component distributions
+            normald1 = tfd.Normal(loc=loc1, scale=scale1)
+            normald2 = tfd.Normal(loc=loc2, scale=scale2)
+
+            # Create a categorical distribution for the weights
+            cat = tfd.Categorical(
+                probs=tf.concat(
+                    [tf.expand_dims(weight, -1), tf.expand_dims(1 - weight, -1)],
+                    axis=-1,
+                )
+            )
+
+            class CustomMixture(tfd.Distribution):
+                def __init__(self, cat, normald1, normald2):
+                    self.cat = cat
+                    self.normald1 = normald1
+                    self.normald2 = normald2
+                    super(CustomMixture, self).__init__(
+                        dtype=normald1.dtype,
+                        reparameterization_type=tfd.FULLY_REPARAMETERIZED,
+                        validate_args=validate_args,
+                        allow_nan_stats=True,
+                    )
+
+                def _sample_n(self, n, seed=None):
+                    indices = self.cat.sample(sample_shape=(n,), seed=seed)
+
+                    # Sample from both normal distributions
+                    samples1 = self.normald1.sample(sample_shape=(n,), seed=seed)
+                    samples2 = self.normald2.sample(sample_shape=(n,), seed=seed)
+
+                    # Stack the samples along a new axis
+                    samples = tf.stack([samples1, samples2], axis=-1)
+
+                    # Gather samples according to indices from the categorical distribution
+                    chosen_samples = tf.gather(
+                        samples,
+                        indices,
+                        batch_dims=tf.get_static_value(tf.rank(indices)),
+                    )
+
+                    return chosen_samples
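+
+                # The mixture log-likelihood below uses the standard identity
+                # log p(x) = logsumexp_i(log w_i + log p_i(x)), which is
+                # numerically stabler than summing the weighted densities.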
+
+                def _log_prob(self, value):
+                    log_prob1 = self.normald1.log_prob(value)
+                    log_prob2 = self.normald2.log_prob(value)
+                    log_probs = tf.stack([log_prob1, log_prob2], axis=-1)
+                    # Stack the weights along the same new trailing axis as the
+                    # log-probs so the two tensors broadcast elementwise.
+                    weighted_log_probs = log_probs + tf.math.log(
+                        tf.stack([weight, 1 - weight], axis=-1)
+                    )
+                    return tf.reduce_logsumexp(weighted_log_probs, axis=-1)
+
+                def _mean(self):
+                    return (
+                        weight * self.normald1.mean()
+                        + (1 - weight) * self.normald2.mean()
+                    )
+
+            mixtured = CustomMixture(cat, normald1, normald2)
+
+            return independent_lib.Independent(
+                mixtured,
+                reinterpreted_batch_ndims=tf.size(event_shape),
+                validate_args=validate_args,
+            )
+
+    @staticmethod
+    def params_size(event_shape=(), name=None):
+        """The number of `params` needed to create a single distribution."""
+        with tf.name_scope(name or "IndependentMixtureNormal_params_size"):
+            event_shape = tf.convert_to_tensor(
+                event_shape, name="event_shape", dtype_hint=tf.int32
+            )
+            return np.int32(5) * _event_size(
+                event_shape, name=name or "IndependentMixtureNormal_params_size"
+            )
+
+    def get_config(self):
+        """Returns the config of this layer.
+        NOTE: At the moment, this configuration can only be serialized if the
+        Layer's `convert_to_tensor_fn` is a serializable Keras object (i.e.,
+        implements `get_config`) or one of the standard values:
+        - `Distribution.sample` (or `"sample"`)
+        - `Distribution.mean` (or `"mean"`)
+        - `Distribution.mode` (or `"mode"`)
+        - `Distribution.stddev` (or `"stddev"`)
+        - `Distribution.variance` (or `"variance"`)
+        """
+        config = {
+            "event_shape": self._event_shape,
+            "convert_to_tensor_fn": _serialize(self._convert_to_tensor_fn),
+            "validate_args": self._validate_args,
+        }
+        base_config = super(IndependentMixtureNormal, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    @property
+    def output(self):
+        """This allows the use of this layer with the shap package."""
+        return super(IndependentMixtureNormal, self).output[0]
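+
+# Usage sketch (editor's illustration; the Dense head is an assumed name):
+#
+#     n = IndependentMixtureNormal.params_size(1)   # -> 5
+#     y = IndependentMixtureNormal(1)(tf.keras.layers.Dense(n)(hidden))
+#
+# The mixture weight passes through a sigmoid, so it always lies in (0, 1).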
scenario_kwargs["dropout"] is None or not scenario_kwargs["mc_dropout"] + ) + is_probabilistic = scenario_kwargs["probabilistic_layer"] is not None + if not is_probabilistic: + return + + pred1 = model(dummy_input) + assert pred1.shape == (32, output_size) + pred2 = model(dummy_input) + try: + # Idependent layers have a "distribution" attribute + pred1_params = pred1.parameters["distribution"].parameters + pred2_params = pred2.parameters["distribution"].parameters + except KeyError: + pred1_params = pred1.parameters + pred2_params = pred2.parameters + + for param in pred1_params.keys(): + try: + param_array1 = pred1_params[param].numpy() + param_array2 = pred2_params[param].numpy() + except AttributeError: + continue + + if is_deterministic: + assert_array_equal(param_array1, param_array2) + else: + with pytest.raises(AssertionError): + assert_array_equal(param_array1, param_array2) @pytest.mark.parametrize("scenario_kwargs", FCN_SCENARIOS) @@ -98,6 +135,7 @@ def test_fully_connected_network(scenario_kwargs): _test_model(model) _test_prediction(model, scenario_kwargs, dummy_input, output_size) + _test_prediction_prob(model, scenario_kwargs, dummy_input, output_size) @pytest.mark.parametrize("scenario_kwargs", FCN_SCENARIOS) @@ -130,6 +168,7 @@ def test_fully_connected_multibranch_network(scenario_kwargs): _test_model(model) _test_prediction(model, scenario_kwargs, dummy_input, output_size) + _test_prediction_prob(model, scenario_kwargs, dummy_input, output_size) @pytest.mark.parametrize("scenario_kwargs", DCN_SCENARIOS) @@ -155,3 +194,4 @@ def test_deep_cross_network(scenario_kwargs): _test_model(model) _test_prediction(model, scenario_kwargs, dummy_input, output_size) + _test_prediction_prob(model, scenario_kwargs, dummy_input, output_size) diff --git a/tests/test_probabilistic_layers.py b/tests/test_probabilistic_layers.py index 939e5f9..55ce556 100644 --- a/tests/test_probabilistic_layers.py +++ b/tests/test_probabilistic_layers.py @@ -92,8 +92,3 @@ def test_probabilistic_model_predict(layer, features_dataset, targets_dataset): assert out_samples.shape[0] == num_samples assert out_samples.shape[1] == data.y.shape[0] assert out_samples.shape[2] == data.y.shape[-1] - out_mean = out_distr.mean() - assert isinstance(out_mean, tf.Tensor) - assert out_mean.ndim == 2 - assert out_mean.shape[0] == data.y.shape[0] - assert out_mean.shape[1] == data.y.shape[-1] diff --git a/tests/test_save_model.py b/tests/test_save_model.py index 73df307..d5c1b31 100644 --- a/tests/test_save_model.py +++ b/tests/test_save_model.py @@ -16,7 +16,7 @@ def _belongs_here(obj, module): return obj[1].__module__ == module.__name__ -ALL_LAYERS = [ +ALL_PROB_LAYERS = [ obj[0] for obj in getmembers(probabilistic_layers, isclass) if _belongs_here(obj, probabilistic_layers) @@ -51,7 +51,7 @@ def _belongs_here(obj, module): @pytest.mark.parametrize("save_format", ["tf", "h5"]) @pytest.mark.parametrize("loss", TEST_LOSSES) -@pytest.mark.parametrize("prob_layer", ALL_LAYERS) +@pytest.mark.parametrize("prob_layer", ALL_PROB_LAYERS) def test_save_model(save_format, loss, prob_layer, tmp_path): """Test model save/load""" @@ -67,7 +67,7 @@ def test_save_model(save_format, loss, prob_layer, tmp_path): 2, hidden_layers=[3], probabilistic_layer=prob_layer, - mc_dropout=True, + mc_dropout=False, ) assert isinstance(model.from_config(model.get_config()), Functional) loss = get_loss(loss) @@ -102,12 +102,29 @@ def test_save_model(save_format, loss, prob_layer, tmp_path): assert completed_process.returncode == 0, "failed to reload model" 
 
 
 @pytest.mark.parametrize("scenario_kwargs", FCN_SCENARIOS)
@@ -98,6 +135,7 @@ def test_fully_connected_network(scenario_kwargs):
 
     _test_model(model)
     _test_prediction(model, scenario_kwargs, dummy_input, output_size)
+    _test_prediction_prob(model, scenario_kwargs, dummy_input, output_size)
 
 
 @pytest.mark.parametrize("scenario_kwargs", FCN_SCENARIOS)
@@ -130,6 +168,7 @@ def test_fully_connected_multibranch_network(scenario_kwargs):
 
     _test_model(model)
     _test_prediction(model, scenario_kwargs, dummy_input, output_size)
+    _test_prediction_prob(model, scenario_kwargs, dummy_input, output_size)
 
 
 @pytest.mark.parametrize("scenario_kwargs", DCN_SCENARIOS)
@@ -155,3 +194,4 @@ def test_deep_cross_network(scenario_kwargs):
 
     _test_model(model)
     _test_prediction(model, scenario_kwargs, dummy_input, output_size)
+    _test_prediction_prob(model, scenario_kwargs, dummy_input, output_size)
diff --git a/tests/test_probabilistic_layers.py b/tests/test_probabilistic_layers.py
index 939e5f9..55ce556 100644
--- a/tests/test_probabilistic_layers.py
+++ b/tests/test_probabilistic_layers.py
@@ -92,8 +92,3 @@ def test_probabilistic_model_predict(layer, features_dataset, targets_dataset):
     assert out_samples.shape[0] == num_samples
     assert out_samples.shape[1] == data.y.shape[0]
     assert out_samples.shape[2] == data.y.shape[-1]
-    out_mean = out_distr.mean()
-    assert isinstance(out_mean, tf.Tensor)
-    assert out_mean.ndim == 2
-    assert out_mean.shape[0] == data.y.shape[0]
-    assert out_mean.shape[1] == data.y.shape[-1]
diff --git a/tests/test_save_model.py b/tests/test_save_model.py
index 73df307..d5c1b31 100644
--- a/tests/test_save_model.py
+++ b/tests/test_save_model.py
@@ -16,7 +16,7 @@ def _belongs_here(obj, module):
     return obj[1].__module__ == module.__name__
 
 
-ALL_LAYERS = [
+ALL_PROB_LAYERS = [
     obj[0]
     for obj in getmembers(probabilistic_layers, isclass)
     if _belongs_here(obj, probabilistic_layers)
@@ -51,7 +51,7 @@ def _belongs_here(obj, module):
 
 @pytest.mark.parametrize("save_format", ["tf", "h5"])
 @pytest.mark.parametrize("loss", TEST_LOSSES)
-@pytest.mark.parametrize("prob_layer", ALL_LAYERS)
+@pytest.mark.parametrize("prob_layer", ALL_PROB_LAYERS)
 def test_save_model(save_format, loss, prob_layer, tmp_path):
     """Test model save/load"""
 
@@ -67,7 +67,7 @@ def test_save_model(save_format, loss, prob_layer, tmp_path):
         2,
         hidden_layers=[3],
         probabilistic_layer=prob_layer,
-        mc_dropout=True,
+        mc_dropout=False,
     )
     assert isinstance(model.from_config(model.get_config()), Functional)
     loss = get_loss(loss)
@@ -102,12 +102,29 @@ def test_save_model(save_format, loss, prob_layer, tmp_path):
     assert completed_process.returncode == 0, "failed to reload model"
 
     input_arr = tf.random.uniform((1, 5))
-    outputs = model(input_arr).mean()
+    pred1 = model(input_arr)
     del model
     tf.keras.backend.clear_session()
     model = tf.keras.saving.load_model(tmp_path, compile=False)
     assert isinstance(model, Functional)
-    np.testing.assert_allclose(model(input_arr).mean(), outputs)
+
+    pred2 = model(input_arr)
+    try:
+        # Independent layers have a "distribution" attribute
+        pred1_params = pred1.parameters["distribution"].parameters
+        pred2_params = pred2.parameters["distribution"].parameters
+    except KeyError:
+        pred1_params = pred1.parameters
+        pred2_params = pred2.parameters
+
+    for param in pred1_params.keys():
+        try:
+            param_array1 = pred1_params[param].numpy()
+            param_array2 = pred2_params[param].numpy()
+        except AttributeError:
+            continue
+
+        np.testing.assert_allclose(param_array1, param_array2)
 
 
 def test_save_model_mlflow(tmp_path):
diff --git a/tests/test_train.py b/tests/test_train.py
index d6b99f7..228ace5 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -155,7 +155,9 @@ def test_train_fromfile(tmp_path, cfg):
     cfg.update({"epochs": num_epochs})
 
     splitter_options = ValidDataSplitterOptions(time="lists", station="lists")
-    datasplitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
+    datasplitter = DataSplitter(
+        splitter_options.time_split, splitter_options.station_split
+    )
     datanormalizer = DataTransformer(**cfg["normalizer"])
     batch_dims = ["forecast_reference_time", "t", "station"]
     datamodule = DataModule(
@@ -190,7 +192,9 @@ def test_train_fromds(features_dataset, targets_dataset, cfg):
     cfg.update({"epochs": num_epochs})
 
     splitter_options = ValidDataSplitterOptions(time="lists", station="lists")
-    datasplitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
+    datasplitter = DataSplitter(
+        splitter_options.time_split, splitter_options.station_split
+    )
     datanormalizer = DataTransformer(**cfg["normalizer"])
     batch_dims = ["forecast_reference_time", "t", "station"]
     datamodule = DataModule(