diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 34d70133ab2411..4ce52706ed97a1 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -252,7 +252,7 @@ Literal::Literal(const Shape& shape) void Literal::SetShape(const Shape& shape) { Shape shape_storage; const Shape* shape_ptr = &shape; - if (LayoutUtil::HasCustomElementSizeInBits(shape)) { + if (shape.IsArray() && LayoutUtil::HasCustomElementSizeInBits(shape)) { shape_storage = shape; shape_storage.mutable_layout()->set_element_size_in_bits(0); shape_ptr = &shape_storage; diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 36a3c263e27c36..42b4340d2ddf82 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -2583,6 +2583,14 @@ TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { EXPECT_FALSE(c1.IsKnown()); } +TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArraysS4Tuple) { + auto inner_shape = ShapeUtil::MakeShape(S4, {4, 4}); + inner_shape.mutable_layout()->set_element_size_in_bits(4); + Literal c1 = Literal::CreateFromShapeWithUnknownLeafArrays( + ShapeUtil::MakeTupleShape({inner_shape})); + EXPECT_FALSE(c1.IsKnown()); +} + TEST_F(LiteralUtilTest, CreatePartiallyKnownTuple) { Literal c1 = Literal::CreateFromShapeWithUnknownLeafArrays( ShapeUtil::MakeShape(F32, {4, 4}));