You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
lstm_seq2seq.py works well with the default fp32 data type when using legacy Keras:
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
With mixed precision enabled, training completes successfully, but inference fails with "Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'." Mixed precision was enabled with:
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
Standalone code to reproduce the issue or tutorial link
Add the following snippet at the beginning of the Keras example https://github.com/keras-team/keras-io/blob/master/examples/nlp/lstm_seq2seq.py:
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
Relevant log output
Traceback (most recent call last):
File "examples/nlp/lstm_seq2seq.py", line 332, in <module>
decoded_sentence = decode_sequence(input_seq)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "examples/nlp/lstm_seq2seq.py", line 297, in decode_sequence
output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_filehtao8ahn.py", line 15, in tf__predict_function
retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
^^^^^
File "/tmp/__autograph_generated_fileg32txcku.py", line 45, in tf__step_function
outputs = ag__.converted_call(ag__.ld(model).distribute_strategy.run, (ag__.ld(run_step),), dict(args=(ag__.ld(data),)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_fileg32txcku.py", line 18, in run_step
outputs = ag__.converted_call(ag__.ld(model).predict_step, (ag__.ld(data),), None, fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_file8lg3jru0.py", line 32, in tf__predict_step
retval_ = ag__.converted_call(ag__.ld(self), (ag__.ld(x),), dict(training=False), fscope)
^^^^^
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_filetingiv6p.py", line 67, in tf____call__
retval_ = ag__.converted_call(ag__.converted_call(ag__.ld(super), (), None, fscope).__call__, tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_filedgxei06x.py", line 242, in tf____call__
ag__.if_stmt(ag__.converted_call(ag__.ld(_in_functional_construction_mode), (ag__.ld(self), ag__.ld(inputs), ag__.ld(args), ag__.ld(kwargs), ag__.ld(input_list)), None, fscope), if_body_11, else_body_11, get_state_11, set_state_11, ('do_return', "kwargs['mask']", 'retval_', 'args', 'input_list', 'inputs', 'kwargs'), 3)
File "/tmp/__autograph_generated_filedgxei06x.py", line 187, in else_body_11
outputs = ag__.converted_call(ag__.ld(call_fn), (ag__.ld(inputs),) + tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_file_fd34cvd.py", line 51, in error_handler
ag__.if_stmt(ag__.converted_call(ag__.ld(hasattr), (ag__.ld(e), '_keras_call_info_injected'), None, fscope_1), if_body, else_body, get_state, set_state, (), 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_file_fd34cvd.py", line 47, in if_body
raise ag__.ld(e)
File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_fileqmtratyp.py", line 29, in tf__call
retval_ = ag__.converted_call(ag__.ld(self)._run_internal_graph, (ag__.ld(inputs),), dict(training=ag__.ld(training), mask=ag__.ld(mask)), fscope)
^^^^^
File "/tmp/__autograph_generated_file0lcomwvu.py", line 174, in tf___run_internal_graph
ag__.for_stmt(ag__.ld(depth_keys), None, loop_body_4, get_state_9, set_state_9, (), {'iterate_names': 'depth'})
File "/tmp/__autograph_generated_file0lcomwvu.py", line 166, in loop_body_4
ag__.for_stmt(ag__.ld(nodes), None, loop_body_3, get_state_8, set_state_8, (), {'iterate_names': 'node'})
File "/tmp/__autograph_generated_file0lcomwvu.py", line 165, in loop_body_3
ag__.if_stmt(ag__.not_(continue__3), if_body_4, else_body_4, get_state_7, set_state_7, ('continue__3',), 0)
File "/tmp/__autograph_generated_file0lcomwvu.py", line 160, in if_body_4
ag__.if_stmt(ag__.not_(continue__3), if_body_3, else_body_3, get_state_6, set_state_6, (), 0)
File "/tmp/__autograph_generated_file0lcomwvu.py", line 145, in if_body_3
outputs = ag__.converted_call(ag__.ld(node).layer, tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_file9p2pb1je.py", line 184, in tf____call__
ag__.if_stmt(ag__.and_(lambda: ag__.ld(initial_state) is None, lambda: ag__.ld(constants) is None), if_body_7, else_body_7, get_state_8, set_state_8, ('do_return', "kwargs['constants']", "kwargs['initial_state']", 'retval_', 'self._num_constants', 'self.constants_spec', 'self.input_spec', 'self.state_spec'), 8)
File "/tmp/__autograph_generated_file9p2pb1je.py", line 175, in else_body_7
ag__.if_stmt(ag__.ld(is_keras_tensor), if_body_6, else_body_6, get_state_7, set_state_7, ('do_return', "kwargs['constants']", "kwargs['initial_state']", 'retval_', 'self.input_spec'), 5)
File "/tmp/__autograph_generated_file9p2pb1je.py", line 168, in else_body_6
retval_ = ag__.converted_call(ag__.converted_call(ag__.ld(super), (), None, fscope).__call__, (ag__.ld(inputs),), dict(**ag__.ld(kwargs)), fscope)
^^^^^
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 44, in tf__error_handler
ag__.if_stmt(ag__.not_(ag__.converted_call(ag__.ld(tf).debugging.is_traceback_filtering_enabled, (), None, fscope)), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 40, in else_body
raise ag__.converted_call(ag__.ld(e).with_traceback, (ag__.ld(filtered_tb),), None, fscope) from None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_fileqr_0kpwh.py", line 34, in else_body
retval_ = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_filedgxei06x.py", line 242, in tf____call__
ag__.if_stmt(ag__.converted_call(ag__.ld(_in_functional_construction_mode), (ag__.ld(self), ag__.ld(inputs), ag__.ld(args), ag__.ld(kwargs), ag__.ld(input_list)), None, fscope), if_body_11, else_body_11, get_state_11, set_state_11, ('do_return', "kwargs['mask']", 'retval_', 'args', 'input_list', 'inputs', 'kwargs'), 3)
File "/tmp/__autograph_generated_filedgxei06x.py", line 187, in else_body_11
outputs = ag__.converted_call(ag__.ld(call_fn), (ag__.ld(inputs),) + tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_file_fd34cvd.py", line 162, in error_handler
raise ag__.converted_call(ag__.ld(new_e).with_traceback, (ag__.ld(e).__traceback__,), None, fscope_1) from None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_filew_gt51m7.py", line 169, in tf__call
ag__.if_stmt(ag__.not_(ag__.ld(self)._could_use_gpu_kernel), if_body_5, else_body_5, get_state_5, set_state_5, ('kwargs', 'last_output', 'outputs', 'runtime', 'states', 'inputs'), 5)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 153, in else_body_5
ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body_4, else_body_4, get_state_4, set_state_4, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 142, in else_body_4
ag__.if_stmt(ag__.converted_call(ag__.ld(tf).executing_eagerly, (), None, fscope), if_body_3, else_body_3, get_state_3, set_state_3, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 134, in else_body_3
last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(lstm_with_backend_selection), (), dict(**ag__.ld(normal_lstm_kwargs)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/__autograph_generated_filecam0bgs0.py", line 118, in tf__lstm_with_backend_selection
ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body, else_body, get_state, set_state, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
File "/tmp/__autograph_generated_filecam0bgs0.py", line 107, in else_body
last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(defun_standard_lstm), (), dict(**ag__.ld(params)), fscope)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: in user code:
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2436, in predict_function *
return step_function(self, iterator)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2409, in run_step *
outputs = model.predict_step(data)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 2377, in predict_step *
return self(x, training=False)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler *
del filtered_tb
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 588, in __call__ *
return super().__call__(*args, **kwargs)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler *
del filtered_tb
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__ *
outputs = call_fn(inputs, *args, **kwargs)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/functional.py", line 514, in call *
return self._run_internal_graph(inputs, training=training, mask=mask)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/functional.py", line 671, in _run_internal_graph *
outputs = node.layer(*args, **kwargs)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/base_rnn.py", line 627, in __call__ *
return super().__call__(inputs, **kwargs)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/training.py", line 560, in error_handler *
filtered_tb = _process_traceback_frames(e.__traceback__)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__ *
outputs = call_fn(inputs, *args, **kwargs)
File "/tmp/__autograph_generated_file_fd34cvd.py", line 162, in error_handler **
raise ag__.converted_call(ag__.ld(new_e).with_traceback, (ag__.ld(e).__traceback__,), None, fscope_1) from None
File "/tmp/__autograph_generated_file_fd34cvd.py", line 34, in error_handler
retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 169, in tf__call **
ag__.if_stmt(ag__.not_(ag__.ld(self)._could_use_gpu_kernel), if_body_5, else_body_5, get_state_5, set_state_5, ('kwargs', 'last_output', 'outputs', 'runtime', 'states', 'inputs'), 5)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 153, in else_body_5
ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body_4, else_body_4, get_state_4, set_state_4, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 142, in else_body_4
ag__.if_stmt(ag__.converted_call(ag__.ld(tf).executing_eagerly, (), None, fscope), if_body_3, else_body_3, get_state_3, set_state_3, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
File "/tmp/__autograph_generated_filew_gt51m7.py", line 134, in else_body_3
last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(lstm_with_backend_selection), (), dict(**ag__.ld(normal_lstm_kwargs)), fscope)
File "/tmp/__autograph_generated_filecam0bgs0.py", line 118, in tf__lstm_with_backend_selection **
ag__.if_stmt(ag__.converted_call(ag__.ld(gru_lstm_utils).use_new_gru_lstm_impl, (), None, fscope), if_body, else_body, get_state, set_state, ('last_output', 'new_c', 'new_h', 'outputs', 'runtime'), 5)
File "/tmp/__autograph_generated_filecam0bgs0.py", line 107, in else_body
last_output, outputs, new_h, new_c, runtime = ag__.converted_call(ag__.ld(defun_standard_lstm), (), dict(**ag__.ld(params)), fscope)
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/lstm.py", line 983, in standard_lstm
last_output, outputs, new_states = backend.rnn(
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/backend.py", line 4985, in rnn
output_time_zero, _ = step_function(
File "/home/sdp/miniconda3/envs/my_conda_env/lib/python3.11/site-packages/tf_keras/src/layers/rnn/lstm.py", line 970, in step
z += backend.dot(h_tm1, recurrent_kernel)
TypeError: Exception encountered when calling layer 'lstm_1' (type LSTM).
Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'.
Call arguments received by layer 'lstm_1' (type LSTM):
• inputs=tf.Tensor(shape=(None, 1, 91), dtype=bfloat16)
• mask=None
• training=False
• initial_state=['tf.Tensor(shape=(None, 256), dtype=float32)', 'tf.Tensor(shape=(None, 256), dtype=float32)']
The text was updated successfully, but these errors were encountered:
Hi @joshuayao, just to confirm: this does not affect current Keras (Keras 3)? Would you be able to reproduce the issue without the legacy flag, but still using mixed precision?
Issue Type
Bug
Source
binary
Keras Version
2.16.0
Custom Code
No
OS Platform and Distribution
No response
Python version
3.11
GPU model and memory
No response
Current Behavior?
lstm_seq2seq.py works well with the default fp32 data type when using the legacy keras.
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
Training with mixed precision completed successfully, but inference failed with: Input 'y' of 'AddV2' Op has type float32 that does not match type bfloat16 of argument 'x'.
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
Standalone code to reproduce the issue or tutorial link
Relevant log output
The text was updated successfully, but these errors were encountered: