diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h
index e706f80c971f33..f11e536f45e675 100644
--- a/third_party/xla/xla/hlo/ir/hlo_instructions.h
+++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h
@@ -2094,6 +2094,13 @@ class HloCustomCallInstruction : public HloCallableInstruction {
     CHECK(layout_constrained());
     return operand_shapes_with_layout_;
   }
+  // Replaces the stored operand shapes-with-layout. Only valid for
+  // layout-constrained custom calls; CHECK-fails otherwise.
+  void set_operand_shapes_with_layout(
+      std::vector<Shape> operand_shapes_with_layout) {
+    CHECK(layout_constrained());
+    operand_shapes_with_layout_ = std::move(operand_shapes_with_layout);
+  }
   void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) {
     custom_call_schedule_ = custom_call_schedule;
   }