From 5d2e5f4580026914d3edfc035260491d464934ac Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Wed, 18 Sep 2024 14:48:40 -0700
Subject: [PATCH] Fix device scope issues (#1841)

We want to always place tf ops on a CPU device; this broke.
---
 keras_nlp/src/tokenizers/byte_pair_tokenizer.py    |  1 +
 .../src/tokenizers/sentence_piece_tokenizer.py     |  1 +
 .../src/tokenizers/word_piece_tokenizer.py         |  1 +
 keras_nlp/src/utils/tensor_utils.py                | 17 +++++++++--------
 keras_nlp/src/utils/tensor_utils_test.py           | 15 ++++++++++++++-
 5 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/keras_nlp/src/tokenizers/byte_pair_tokenizer.py b/keras_nlp/src/tokenizers/byte_pair_tokenizer.py
index 5eecf4cbf..dfee22f12 100644
--- a/keras_nlp/src/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/src/tokenizers/byte_pair_tokenizer.py
@@ -540,6 +540,7 @@ def tokenize(self, inputs):
         if self.add_prefix_space:
             inputs = tf.strings.join([" ", inputs])
 
+        inputs = tf.convert_to_tensor(inputs)
         unbatched = inputs.shape.rank == 0
         if unbatched:
             inputs = tf.expand_dims(inputs, 0)
diff --git a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
index ea9f4d6f4..06594d14d 100644
--- a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
+++ b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
@@ -238,6 +238,7 @@ def _check_vocabulary(self):
     @preprocessing_function
     def tokenize(self, inputs):
         self._check_vocabulary()
+        inputs = tf.convert_to_tensor(inputs)
         unbatched = inputs.shape.rank == 0
         if unbatched:
             inputs = tf.expand_dims(inputs, 0)
diff --git a/keras_nlp/src/tokenizers/word_piece_tokenizer.py b/keras_nlp/src/tokenizers/word_piece_tokenizer.py
index f228afae2..14d699262 100644
--- a/keras_nlp/src/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/src/tokenizers/word_piece_tokenizer.py
@@ -473,6 +473,7 @@ def _check_vocabulary(self):
     @preprocessing_function
     def tokenize(self, inputs):
         self._check_vocabulary()
+        inputs = tf.convert_to_tensor(inputs)
         unbatched = inputs.shape.rank == 0
         pattern = None
         if self.split and self.special_tokens_in_strings:
diff --git a/keras_nlp/src/utils/tensor_utils.py b/keras_nlp/src/utils/tensor_utils.py
index 94594202a..7502c38bc 100644
--- a/keras_nlp/src/utils/tensor_utils.py
+++ b/keras_nlp/src/utils/tensor_utils.py
@@ -53,20 +53,21 @@ def preprocessing_function(fn):
 
     params = inspect.signature(fn).parameters
     accepts_labels = all(k in params for k in ("x", "y", "sample_weight"))
-    with tf.device("cpu"):
-        if not accepts_labels:
+    if not accepts_labels:
 
-            @functools.wraps(fn)
-            def wrapper(self, x, **kwargs):
+        @functools.wraps(fn)
+        def wrapper(self, x, **kwargs):
+            with tf.device("cpu"):
                 x = convert_preprocessing_inputs(x)
                 with no_convert_scope():
                     x = fn(self, x, **kwargs)
                 return convert_preprocessing_outputs(x)
 
-        else:
+    else:
 
-            @functools.wraps(fn)
-            def wrapper(self, x, y=None, sample_weight=None, **kwargs):
+        @functools.wraps(fn)
+        def wrapper(self, x, y=None, sample_weight=None, **kwargs):
+            with tf.device("cpu"):
                 x, y, sample_weight = convert_preprocessing_inputs(
                     (x, y, sample_weight)
                 )
@@ -74,7 +75,7 @@ def wrapper(self, x, y=None, sample_weight=None, **kwargs):
                     x = fn(self, x, y=y, sample_weight=sample_weight, **kwargs)
                 return convert_preprocessing_outputs(x)
 
-        return wrapper
+    return wrapper
 
 
 def convert_preprocessing_inputs(x):
diff --git a/keras_nlp/src/utils/tensor_utils_test.py b/keras_nlp/src/utils/tensor_utils_test.py
index c0f34595c..463a26729 100644
--- a/keras_nlp/src/utils/tensor_utils_test.py
+++ b/keras_nlp/src/utils/tensor_utils_test.py
@@ -23,12 +23,13 @@
 from keras_nlp.src.utils.tensor_utils import convert_preprocessing_outputs
 from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch
 from keras_nlp.src.utils.tensor_utils import is_tensor_type
+from keras_nlp.src.utils.tensor_utils import preprocessing_function
 from keras_nlp.src.utils.tensor_utils import tensor_to_list
 
 
 class ConvertHelpers(TestCase):
     def test_basics(self):
-        inputs = ops.array([1, 2, 3])
+        inputs = [1, 2, 3]
         # Convert to tf.
         outputs = convert_preprocessing_inputs(inputs)
         self.assertAllEqual(outputs, ops.array(inputs))
@@ -92,6 +93,18 @@ def to_list(x):
         inputs = tree.flatten(tree.map_structure(to_list, inputs))
         self.assertAllEqual(outputs, inputs)
 
+    def test_placement(self):
+        # Make sure we always place preprocessing on the CPU on all backends.
+        @preprocessing_function
+        def test(self, inputs):
+            for x in inputs:
+                if isinstance(x, tf.Tensor):
+                    self.assertTrue("CPU" in x.device)
+                    self.assertFalse("GPU" in x.device)
+            return inputs
+
+        test(self, ([1, 2, 3], ["foo", "bar"], "foo"))
+
 
 class TensorToListTest(TestCase):
     def test_ragged_input(self):
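
For context (not part of the patch itself): tf.device is a context manager, so a scope entered once at decoration time, as in the old tensor_utils.py code, has already exited by the time the wrapped function actually runs; the scope must be re-entered inside the wrapper on every call, which is what the tensor_utils.py hunk above does. Below is a minimal runnable sketch of the corrected pattern; the names on_cpu and make_tensor are hypothetical, for illustration only.

    import functools

    import tensorflow as tf

    def on_cpu(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The scope is entered at call time, so ops created while fn
            # runs are placed on the CPU. Entering it at decoration time
            # instead would have no effect on later calls.
            with tf.device("cpu"):
                return fn(*args, **kwargs)

        return wrapper

    @on_cpu
    def make_tensor(x):
        return tf.convert_to_tensor(x)

    print(make_tensor([1, 2, 3]).device)  # expect a ".../device:CPU:0" string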