From d56766cb89e15c9ee5ca385bc9bf863aa5aa5047 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Fri, 20 Sep 2024 19:32:19 -0700 Subject: [PATCH] Add `SparseTensor` input validation to SparseCore conversion op. PiperOrigin-RevId: 677054398 --- .../tpu/kernels/sparse_core_preprocess_ops.cc | 31 ++++++++++++------- .../tpu/kernels/sparse_core_preprocess_ops.h | 2 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc index 8548a92efe0495..e25889827a49f3 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc @@ -92,7 +92,9 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, // tensor. } else if (indices_or_row_splits.dims() == 2 && indices_or_row_splits.NumElements() >= 0) { - // TODO(pineapplejuice233): Add checking logic for sparse tensor input. + // NOTE(mrry): Checking logic for SparseTensor inputs is in + // `ComputeRowIdsBeforePadding()`, to avoid an extra traversal of the + // indices matrix. } else if (indices_or_row_splits.dims() == 1 && indices_or_row_splits.NumElements() > 0) { // Ragged tensor. @@ -114,6 +116,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, const int32 total_id_count, + const int32 sample_count, int32* row_ids_before_padding) { // The only difference between dense tensor, sparse tensor and ragged tensor // is the row ids output. @@ -140,7 +143,14 @@ Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, if (current_row_id < previous_row_id) { return absl::InvalidArgumentError( "Invalid indices_or_row_splits input, indices of SparseTensor need " - "to be sorted in ascending order."); + "to be sorted in ascending (non-decreasing) order."); + } + if (current_row_id >= sample_count) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid indices_or_row_splits input, indices of SparseTensor " + "contained a row_id ", + current_row_id, " that was >= the sample count (", sample_count, + ").")); } *(row_ids_before_padding + i) = current_row_id; previous_row_id = current_row_id; @@ -309,9 +319,9 @@ class ConvertToCooTensorOp : public OpKernel { auto row_ids_before_dedup = std::make_unique(total_id_count); - OP_REQUIRES_OK( - ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count, - row_ids_before_dedup.get())); + OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding( + *indices_or_row_splits, total_id_count, + sample_count_, row_ids_before_dedup.get())); // Compute the rescaled gains for non-sum combiners. std::optional> gains_rescale = @@ -520,9 +530,8 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { "The number of minibatches per sparse core is ", num_minibatch_per_sc, ". But the max minibatches per sparse core is set to be ", max_minibatches_per_sc_, " which is smaller."))); - VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " - << "program_key = '" << program_key << "'" - << ", table_name = '" << table_name_ << "'" + VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " << "program_key = '" + << program_key << "'" << ", table_name = '" << table_name_ << "'" << ", max_ids = " << max_ids_per_partition << ", max_uniques = " << max_unique_ids_per_partition << ", num_minibatch_per_sc = " << num_minibatch_per_sc; @@ -1213,9 +1222,9 @@ void ConvertToListOfSparseCoreCooTensorsOp::Compute(OpKernelContext* ctx) { auto row_ids_before_dedup = std::unique_ptr( new std::remove_extent_t[total_id_count]); - OP_REQUIRES_OK( - ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count, - row_ids_before_dedup.get())); + OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, + total_id_count, sample_count_, + row_ids_before_dedup.get())); // Compute the rescaled gains for non-sum combiners. std::optional> gains_rescale = diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h index ce43521cbc5147..d3651d04de2d6e 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h @@ -55,7 +55,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, // Compute the row id list before padding. Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, - int32 total_id_count, + int32 total_id_count, int32 sample_count, int32* row_ids_before_padding); class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel {