Skip to content

Commit

Permalink
String Tensor SplitToSequence fix (microsoft#19942)
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp authored and Ted Themistokleous committed May 7, 2024
1 parent 4309561 commit 37634b0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
int num_remaining_splits = 0;
InlinedVector<int64_t> split_sizes;
const bool is_string_type = input.IsDataTypeString();
const size_t element_size = (is_string_type) ? 0U : input.DataType()->Size();
const size_t element_size = input.DataType()->Size();

// figure out split_scalar or split_sizes
if (p_split_input) {
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,19 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
test.Run();
}

TEST(SequenceOpsTest, SplitToSequence_StringSplit) {
OpTester test("SplitToSequence", 11);
test.AddInput<std::string>("input", {3}, std::vector<std::string>({"Test string", "Another string", "A third and much longer string"}));
int64_t axis = 0;
test.AddAttribute("axis", axis);
SeqTensors<std::string> output;
output.AddTensor({1}, {"Test string"});
output.AddTensor({1}, {"Another string"});
output.AddTensor({1}, {"A third and much longer string"});
test.AddSeqOutput("S2", output);
test.Run();
}

TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
OpTester test("SplitToSequence", 11);
test.AddInput<float>("input", {5, 2}, GetConsecutiveVector<float>(1.f, 10));
Expand Down

0 comments on commit 37634b0

Please sign in to comment.