
bayesian_neural_networks failed with Mixed Precision enabled #1860

Open
LifengWang opened this issue May 13, 2024 · 3 comments
@LifengWang

Issue Type: Bug
Source: binary
Keras Version: 2.16.0
Custom Code: No
OS Platform and Distribution: Linux Ubuntu 20.04
Python version: 3.10
GPU model and memory: No response

Current Behavior?

The bayesian_neural_networks example works well with the default fp32 data type when using legacy Keras:

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"

However, when I enable mixed precision for bfloat16 (or float16) with the following code, the bayesian_neural_networks example fails:

import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

Standalone code to reproduce the issue or tutorial link

Add the following code snippet at the beginning of the code example at https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/bayesian_neural_networks.py.

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

Relevant log output

Traceback (most recent call last):
  File "/root/lifeng/keras-io/examples/keras_recipes/bayesian_neural_networks.py", line 304, in <module>
    bnn_model_small = create_bnn_model(train_sample_size)
  File "/root/lifeng/keras-io/examples/keras_recipes/bayesian_neural_networks.py", line 274, in create_bnn_model
    features = tfp.layers.DenseVariational(
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/dense_variational_v2.py", line 123, in call
    self.add_loss(self._kl_divergence_fn(q, r))
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/dense_variational_v2.py", line 187, in _fn
    kl = kl_divergence_fn(distribution_a, distribution_b)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/dense_variational_v2.py", line 180, in kl_divergence_fn
    distribution_a.log_prob(z) - distribution_b.log_prob(z),
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1287, in log_prob
    return self._call_log_prob(value, name, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1269, in _call_log_prob
    return self._log_prob(value, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 114, in _log_prob
    return self.tensor_distribution._log_prob(value, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/internal/distribution_util.py", line 1350, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/mvn_linear_operator.py", line 243, in _log_prob
    return super(MultivariateNormalLinearOperator, self)._log_prob(x)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 364, in _log_prob
    log_prob, _ = self.experimental_local_measure(
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 611, in experimental_local_measure
    x = self.bijector.inverse(y, **bijector_kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py", line 1389, in inverse
    return self._call_inverse(y, name, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py", line 1362, in _call_inverse
    y = nest_util.convert_to_nested_tensor(
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/internal/nest_util.py", line 503, in convert_to_nested_tensor
    return convert_fn((), value, dtype, dtype_hint, name=name)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/internal/nest_util.py", line 495, in convert_fn
    return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
ValueError: Exception encountered when calling layer "dense_variational" (type DenseVariational).

y: Tensor conversion requested dtype float32 for Tensor with dtype bfloat16: <tf.Tensor 'dense_variational/sequential/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/chain_of_shift_of_scale_matvec_linear_operator/forward/shift/forward/add:0' shape=(96,) dtype=bfloat16>

Call arguments received by layer "dense_variational" (type DenseVariational):
  • inputs=tf.Tensor(shape=(None, 11), dtype=bfloat16)
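The failure at the bottom of the trace can be reproduced in isolation: `tf.convert_to_tensor` refuses to implicitly cast an existing tensor to a different float dtype, which is exactly what happens when the bfloat16 sample from the mixed-precision layer reaches TFP's float32 bijector. A minimal demonstration of that conversion error:

```python
import tensorflow as tf

# A bfloat16 tensor, standing in for the layer's sample under mixed_bfloat16.
y = tf.constant([1.0, 2.0], dtype=tf.bfloat16)

try:
    # The TFP bijector effectively requests a float32 conversion of y.
    tf.convert_to_tensor(y, dtype=tf.float32)
except ValueError as e:
    # Raises: "Tensor conversion requested dtype float32 for Tensor with
    # dtype bfloat16: ...", matching the log output above.
    print(type(e).__name__)
```
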
@LifengWang (Author)

The standard neural network works well with mixed precision, but the Bayesian neural network fails.

@chunduriv (Collaborator)

@LifengWang,

Thanks for reporting the issue. I have reproduced the behavior; please see the gist for reference.

@mattdangerw (Member)

I'm not sure that, in general, we expect all examples to run under mixed precision without error; that will often depend on the custom layers/models created by the author of a given example.

But if anyone has interest in diving in here and fixing this up, contributions are welcome!
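One avenue for anyone who does dig in: Keras lets an individual layer opt out of the global mixed-precision policy by passing dtype="float32" to its constructor, so the variational layers (and the float32-only TFP math behind them) could run entirely in float32 while the rest of the model stays in bfloat16. Below is a minimal sketch of that pattern using plain tf.keras Dense layers rather than tfp.layers.DenseVariational — an untested suggestion for the TFP case, not a confirmed fix:

```python
import tensorflow as tf

# Enable mixed precision globally, as in the bug report.
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

inputs = tf.keras.Input(shape=(11,))
# This layer follows the global policy: float32 variables, bfloat16 compute,
# so its output tensor is bfloat16.
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
# Pinning dtype="float32" opts this layer out of the policy entirely;
# Keras casts its bfloat16 input back to float32 automatically, and the
# output is float32.
features = tf.keras.layers.Dense(4, dtype="float32")(hidden)

print(tf.as_dtype(hidden.dtype))    # bfloat16
print(tf.as_dtype(features.dtype))  # float32
```

Whether the same dtype override is enough for tfp.layers.DenseVariational (whose KL-divergence loss also samples from the posterior) would need to be verified.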
