From 5038e5893297b769b7ab7a7e985d5ada4accf323 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:21:21 -0800 Subject: [PATCH] Update black formatting (#1415) --- keras_nlp/layers/modeling/transformer_decoder.py | 8 +++++--- keras_nlp/models/bloom/bloom_decoder.py | 8 +++++--- keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py | 8 +++++--- keras_nlp/models/llama/llama_decoder.py | 8 +++++--- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index 3a3cda3f2..cb7c77933 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -469,9 +469,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index b3f8b80da..c5dbb5135 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -174,9 +174,11 @@ def _compute_attention_mask( batch_size, input_length, output_length, - 0 - if attention_cache_update_index is None - else attention_cache_update_index, + ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py index ae646fb2b..f80cafd52 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py @@ -211,9 +211,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 47bac478c..2137831be 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -172,9 +172,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask)