diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index 29117bcfcaa82e..31910351d17c85 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -1582,15 +1582,28 @@ Status HloEvaluator::HandleTuple(const HloInstruction* tuple) { } } + // Inline part of LiteralUtil::MakeTuple() that avoids creating the leaf + // buffers; these buffers can be extremely large. + std::vector element_shapes; + element_shapes.reserve(operand_literals.size()); + for (const auto* element : operand_literals) { + element_shapes.push_back(&element->shape()); + } + Literal new_result = Literal::CreateFromShapeWithUndeterminedLeafArrays( + ShapeUtil::MakeTupleShapeWithPtrs(element_shapes)); + for (int i = 0, end = operand_literals.size(); i < end; ++i) { + TF_RETURN_IF_ERROR( + new_result.CopyFrom(*operand_literals[i], /*dest_shape_index=*/{i})); + } + if (evaluated_.contains(tuple)) { - Literal new_result = LiteralUtil::MakeTuple(operand_literals); CHECK(new_result.IsDetermined(visitor_shape_index_)); TF_RETURN_IF_ERROR( - evaluated_[tuple].CopyFrom(new_result, + evaluated_[tuple].CopyFrom(std::move(new_result), /*dest_shape_index=*/visitor_shape_index_, /*src_shape_index=*/visitor_shape_index_)); } else { - evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals); + evaluated_[tuple] = std::move(new_result); } return OkStatus(); }