From 43a5d4f5d8a067f3d4e2343d7ab757223fd92ae3 Mon Sep 17 00:00:00 2001 From: TJ Xu Date: Tue, 25 Jun 2024 17:38:55 -0700 Subject: [PATCH 1/4] PR #13813: [NVIDIA GPU] Assign a fixed index for cached activation Imported from GitHub PR https://github.com/openxla/xla/pull/13813 gpu_windowed_einsum_handler pass has been re-using the empty buffer of the transformed while loop. This buffer is given by the spmd dot_handler pass. The shape of the buffer has changed from the allgathered shape of the sharded operand to the output shape of the dot which leads to a shape incompatibility error. To make the gpu handler completely safe, we will make a new element in the tuple to host the cached activation with the desired shape. The slice index of where to write the slice into the full buffer also changes based on whether the contracting or the non-contracting dim is sharded. With the new element, we will need to determine the slice index ourselves in the handler pass. Copybara import of the project: -- ceeff8e5da8ecb3f382bbd8dee83e2f0c909b22d by TJ Xu : Assign a fixed index for cached activation Cache correct activation slice when contracting dim is sharded -- 233763b8efb4ab0045eb998b437c7b28c8f776c8 by TJ Xu : Simplified logic in gpu einsum handler to be more generic -- 2220cd1a022ad519cd23ab36c31c70c9627fc76d by TJ Xu : remove un-used variables Merging this change closes #13813 PiperOrigin-RevId: 646666635 --- .../gpu/gpu_windowed_einsum_handler.cc | 149 +++++++++++++++--- .../gpu/gpu_windowed_einsum_handler_test.cc | 86 +++++++++- 2 files changed, 206 insertions(+), 29 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc index ce1dfa4f1c7863..8f5e26124f24a4 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc @@ -378,8 +378,16 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* 
comp, return changed; } +static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { + const HloInstruction* loop_tuple = while_loop->operand(0); + const Shape& tuple_shape = loop_tuple->shape(); + CHECK(tuple_shape.IsTuple()); + return tuple_shape.tuple_shapes_size(); +} + absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { + GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, + HloInstruction* ag_with_shared_operand) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ -392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( } // Get the output operand of the full buffer. HloInstruction* root = while_body->root_instruction(); + // Change loop body to include the new input and output element. + HloInstruction* input_tuple = while_body->parameter_instruction(0); + const Shape& input_shape = input_tuple->shape(); // The full buffer that we will use to cache the accumulated activation - // is the 4th operand in the output tuple. - int64_t full_cache_buffer_index = 3; + // is the last operand in the output tuple. 
+ int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); + std::vector new_input_shapes(input_shape.tuple_shapes().begin(), + input_shape.tuple_shapes().end()); + new_input_shapes.push_back(ag_with_shared_operand->shape()); + // Update body input shape + Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); + *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = - root->mutable_operand(full_cache_buffer_index); - HloInstruction* new_full_buffer_output; + while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + ag_with_shared_operand->shape(), input_tuple, + full_cache_buffer_index)); + + // Update condition input shape + HloComputation* cond_comp = loop->while_condition(); + HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); + *cond_input_tuple->mutable_shape() = new_input_shape; + + // Update input to the while instruction in parent computation + HloInstruction* original_while_input = loop->mutable_operand(0); + HloComputation* parent_comp = loop->parent(); + std::vector new_operands( + original_while_input->operands().begin(), + original_while_input->operands().end()); + new_operands.push_back( + parent_comp->AddInstruction(HloInstruction::CreateBroadcast( + ag_with_shared_operand->shape(), + parent_comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(new_input_shapes[0].element_type()))), + {}))); + HloInstruction* new_while_input = + parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + loop->ReplaceOperandWithDifferentShape(0, new_while_input)); + TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( + original_while_input, new_while_input)); + *loop->mutable_shape() = new_input_shape; + + HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices // This should just be a constant(0) HloInstruction* dus_boundary_constant; + // The slice 
we need this time is the output of the first + // collective-permute + HloInstruction* first_cp_output; + for (HloInstruction* gte_user : input_gte->users()) { + if (gte_user->opcode() == HloOpcode::kCollectivePermute) { + first_cp_output = gte_user; + break; + } + } for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) { HloInstruction* slice_indices; // If we have a DUS(PARAM,DS) pattern, we need to update the output @@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( dus_boundary_constant->shape(), slice_indices)); VLOG(5) << "Created slice op for second slice: " << slice_indices->ToString(); - // The slice we need this time is the output of the first - // collective-permute - HloInstruction* cp_output; - for (HloInstruction* gte_user : input_gte->users()) { - if (gte_user->opcode() == HloOpcode::kCollectivePermute) { - cp_output = gte_user; - break; - } - } new_full_buffer_output = while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_buffer_output_gte->shape(), full_buffer_output_gte, - cp_output, + first_cp_output, {dus_boundary_constant, slice_indices, dus_boundary_constant})); } + + // If we have a Dot(DS(parameter_index1)), then operands are sharded along + // the contracting dim. Slice indices will be the contracting dim's slices. + HloInstruction* slice_index; + HloInstruction* ds_index_constant; + HloInstruction* remainder; + HloInstruction* ds_param; + // There will be 2 dynamic-slices for unrolled loops, match for each one to + // get the slice index which will be used to write the corresponding + // received shard into cached activation buffer. For unrolled loops, we need + // to write to the final buffer twice per iteration, so we need to match for + // the correct slice index based on each DS. 
+ if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) && + Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) { + for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size(); + ds_op_i++) { + if (!Match( + ds_param->mutable_operand(ds_op_i), + m::Reshape(&slice_index, m::DynamicSlice(m::Constant(), + m::Op(&remainder)))) && + !Match(ds_param->mutable_operand(ds_op_i), + m::Constant(&ds_index_constant))) { + return absl::OkStatus(); + } + } + // First DS has slice index calculated based on loop iterator + // Remainder(add(gte, partition_id)) + if (Match(remainder, + m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) { + full_buffer_output_gte = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + input_gte, + {ds_index_constant, ds_index_constant, slice_index})); + } + // Second DS has slice index calculated based on loop iterator+1 hence + // Remainder(add(add(gte, 1), partition_id)) + if (Match(remainder, + m::Remainder( + m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()), + m::Op()))) { + new_full_buffer_output = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + first_cp_output, + {ds_index_constant, ds_index_constant, slice_index})); + } + } } - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index, - new_full_buffer_output)); + std::vector original_operands(root->operands().begin(), + root->operands().end()); + original_operands.push_back(new_full_buffer_output); + HloInstruction* new_output_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(original_operands)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); return absl::OkStatus(); } @@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { VLOG(5) << "Found all-gather that shares the same operand with 
a " "windowed einsum loop : " << loop->ToString(); + + if (!ag_loop.consumed) { + TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( + ag_loop, ag_with_shared_operand)); + ag_loop.consumed = true; + } int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloInstruction* new_gte = comp->AddInstruction( - HloInstruction::CreateGetTupleElement(loop, 3)); + HloComputation* comp = dot->parent(); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( dot->ReplaceOperandWith(cache_output_index, new_gte)); TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR( - ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); - ag_loop.consumed = true; - } } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc index 23257e1c71a34b..6f23319980e90c 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc @@ -269,23 +269,22 @@ ENTRY main.12_spmd { FindInstructionByName(module->entry_computation(), "dot.7"); // dot.7 should now consume output of the windowed einsum while loop. EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); - EXPECT_EQ(inst->operand(0)->tuple_index(), 3); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); // while loop's root should now have a chain of DUS. 
HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction(); EXPECT_THAT(ag_while_root, GmockMatch(m::Tuple( - m::Op(), m::Op(), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op(), m::DynamicUpdateSlice( m::DynamicUpdateSlice( m::GetTupleElement(m::Parameter()) .WithPredicate([](const HloInstruction* instr) { - return instr->tuple_index() == 3; + return instr->tuple_index() == 5; }), m::Op(), m::Op(), m::Op(), m::Op()), - m::Op(), m::Op(), m::Op(), m::Op()), - m::Op()))); + m::Op(), m::Op(), m::Op(), m::Op())))); } TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( @@ -838,5 +837,82 @@ ENTRY main.9_spmd { )"); } +TEST_F(GpuWindowedEinsumHanlderTest, + AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 + +windowed_dot_general_body_ag { + param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0 + collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1 + get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2 + constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584}) + get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4 + partition-id.194 
= u32[] partition-id() + add.4309 = u32[] add(get-tuple-element.592, partition-id.194) + constant.11431 = u32[] constant(8) + remainder.194 = u32[] remainder(add.4309, constant.11431) + dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1} + reshape.12959 = s32[] reshape(dynamic-slice.388) + constant.11433 = s32[] constant(0) + dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288} + dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244) + constant.11434 = u32[] constant(1) + add.4312 = u32[] add(get-tuple-element.592, constant.11434) + add.4313 = u32[] add(add.4312, partition-id.194) + remainder.195 = u32[] remainder(add.4313, constant.11431) + dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1} + reshape.12960 = s32[] reshape(dynamic-slice.390) + dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288} + dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245) + get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3 + add.4315 = u32[] add(add.4312, constant.11434) + ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + 
get-tuple-element = u32[] get-tuple-element(param), index=4 + constant = u32[] constant(4) + ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT +} + +ENTRY main.12_spmd { + param.4 = bf16[16,2048,512]{2,1,0} parameter(0) + param.5 = bf16[4096,6288]{1,0} parameter(1) + constant.22 = bf16[] constant(0) + broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={} + constant.24 = u32[] constant(0) + tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24) + while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 + all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true + param.6 = bf16[16,2048,6288]{2,1,0} parameter(2) + ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* ag_loop = + FindInstructionByName(module->entry_computation(), "while"); + HloInstruction* inst = + FindInstructionByName(module->entry_computation(), "dot.7"); + // dot.7 should now consume output of the windowed einsum while loop. 
+ EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); + EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); +} } // namespace } // namespace xla::gpu From 7343933df4f96affee731371c674782409677fa3 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Tue, 25 Jun 2024 17:40:21 -0700 Subject: [PATCH 2/4] PR #14073: Add select(compare(a, b, GT/GE), a, b) => or(a, b) to algsimp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/14073 In one of the customer's HLOs (in the reduceOp computation function), I found the following pattern: ```c p0 = pred[]{0} parameter(0) p1 = pred[]{0} parameter(1) compare = pred[]{0} compare(p0, p1), direction=GT select = pred[]{0} select(compare, p0, p1) ``` It can be simplified to `logical_or`. This PR adds the following patterns to algsimp ```c select(compare(a, b, GT/GE), a, b) => or(a, b) select(compare(a, b, LT/LE), a, b) => and(a, b) select(compare(a, b, EQ), a, b) => b select(compare(a, b, NE), a, b) => a a,b ∈ PRED ``` Copybara import of the project: -- 6fe68d7319b272ff041b67e038359540cddda489 by Alexander Pivovarov : Add select(compare(a, b, GT/GE), a, b) => or(a, b) to algsimp Merging this change closes #14073 PiperOrigin-RevId: 646667024 --- .../xla/xla/service/algebraic_simplifier.cc | 27 ++++++ .../xla/service/algebraic_simplifier_test.cc | 89 +++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index dd13ffebd9e852..641fedf0c72405 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { select->mutable_operand(0)->shape(), HloOpcode::kNot, select->mutable_operand(0))); } + // 
select(compare(a, b, GT/GE), a, b) => or(a, b) + // select(compare(a, b, LT/LE), a, b) => and(a, b) + // select(compare(a, b, EQ), a, b) => b + // select(compare(a, b, NE), a, b) => a + HloInstruction *compare, *lhs, *rhs; + if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) && + Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) { + auto cmp_dir = compare->comparison_direction(); + if (cmp_dir == ComparisonDirection::kGt || + cmp_dir == ComparisonDirection::kGe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kOr, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kLt || + cmp_dir == ComparisonDirection::kLe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kAnd, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kEq) { + return ReplaceInstruction(select, rhs); + } + if (cmp_dir == ComparisonDirection::kNe) { + return ReplaceInstruction(select, lhs); + } + } } // select(pred, xs, dynamic_update_slice(xs, x, i)) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 00970a51546b1a..921098aa7565e8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) { GmockMatch(m::Not(m::Parameter(0)))); } +// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectGtCompare) { + for (const auto cmp_dir : {"GT", "GE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + 
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Or(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectLtCompare) { + for (const auto cmp_dir : {"LT", "LE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::And(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectEqCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=EQ + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(1))); +} + +// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectNeCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + 
EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +// select(compare(a, b, NE), b, a) ≠> a - wrong operands order +TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p1, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i) TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) { From 60a8e1a2d4cebeee0424bf4e38ccff2b61ef9b40 Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Tue, 25 Jun 2024 17:51:51 -0700 Subject: [PATCH 3/4] Move tensorflow/lite/tools/optimize/testdata models to tf/compiler/mlir/lite/quantization/lite/testdata. 
PiperOrigin-RevId: 646669412 --- .../mlir/lite/quantization/lite/BUILD | 93 ++++++++------- .../quantization/lite/quantize_model_test.cc | 65 ++++++----- .../lite/quantize_weights_test.cc | 17 +-- .../mlir/lite/quantization/lite}/test_util.cc | 10 +- .../mlir/lite/quantization/lite}/test_util.h | 16 +-- .../quantization/lite}/testdata/README.md | 0 .../lite}/testdata/add_with_const_input.bin | Bin .../quantization/lite}/testdata/argmax.bin | Bin .../lite}/testdata/broadcast_to.bin | Bin .../quantization/lite}/testdata/concat.bin | Bin .../quantization/lite}/testdata/custom_op.bin | Bin .../lite/quantization/lite}/testdata/fc.bin | Bin .../quantization/lite}/testdata/fc_qat.bin | Bin .../quantization/lite}/testdata/gather_nd.bin | Bin .../lite}/testdata/lstm_calibrated.bin | Bin .../lite}/testdata/lstm_calibrated2.bin | Bin .../lite}/testdata/lstm_quantized.bin | Bin .../lite}/testdata/lstm_quantized2.bin | Bin .../quantization/lite}/testdata/maximum.bin | Bin .../quantization/lite}/testdata/minimum.bin | Bin .../quantization/lite}/testdata/mixed.bin | Bin .../quantization/lite}/testdata/mixed16x8.bin | Bin .../testdata/multi_input_add_reshape.bin | Bin .../lite/quantization/lite}/testdata/pack.bin | Bin .../lite}/testdata/quantized_with_gather.bin | Bin .../testdata/resource_vars_calibrated.bin | Bin ...single_avg_pool_min_minus_5_max_plus_5.bin | Bin .../lite}/testdata/single_conv_no_bias.bin | Bin .../single_conv_weights_min_0_max_plus_10.bin | Bin ...onv_weights_min_minus_127_max_plus_127.bin | Bin .../single_softmax_min_minus_5_max_plus_5.bin | Bin .../quantization/lite}/testdata/split.bin | Bin .../lite}/testdata/svdf_calibrated.bin | Bin .../lite}/testdata/svdf_quantized.bin | Bin .../quantization/lite}/testdata/transpose.bin | Bin ...nidirectional_sequence_lstm_calibrated.bin | Bin ...unidirectional_sequence_lstm_quantized.bin | Bin .../quantization/lite}/testdata/unpack.bin | Bin .../testdata/weight_shared_between_convs.bin | Bin 
.../quantization/lite}/testdata/where.bin | Bin tensorflow/lite/tools/optimize/BUILD | 103 +++++++---------- .../lite/tools/optimize/model_utils_test.cc | 1 - .../tools/optimize/quantization_utils_test.cc | 4 +- .../tools/optimize/quantize_model_test.cc | 106 +++++++++++------- .../tools/optimize/quantize_weights_test.cc | 16 +-- .../reduced_precision_support_test.cc | 1 - 46 files changed, 228 insertions(+), 204 deletions(-) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/test_util.cc (95%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/test_util.h (93%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/README.md (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/add_with_const_input.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/argmax.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/broadcast_to.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/concat.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/custom_op.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/fc.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/fc_qat.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/gather_nd.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_calibrated2.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/lstm_quantized.bin (100%) rename tensorflow/{lite/tools/optimize 
=> compiler/mlir/lite/quantization/lite}/testdata/lstm_quantized2.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/maximum.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/minimum.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/mixed.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/mixed16x8.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/multi_input_add_reshape.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/pack.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/quantized_with_gather.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/resource_vars_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_avg_pool_min_minus_5_max_plus_5.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_no_bias.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_weights_min_0_max_plus_10.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_conv_weights_min_minus_127_max_plus_127.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/single_softmax_min_minus_5_max_plus_5.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/split.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/svdf_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/svdf_quantized.bin 
(100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/transpose.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unidirectional_sequence_lstm_calibrated.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unidirectional_sequence_lstm_quantized.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/unpack.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/weight_shared_between_convs.bin (100%) rename tensorflow/{lite/tools/optimize => compiler/mlir/lite/quantization/lite}/testdata/where.bin (100%) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 78ae512eb54202..e57c2b30808d82 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -10,6 +10,10 @@ package( licenses = ["notice"], ) +exports_files(glob([ + "testdata/*.bin", +])) + package_group( name = "friends", packages = [ @@ -123,39 +127,39 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - 
"//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -163,12 +167,12 @@ tf_cc_test( ], deps = [ ":quantize_model", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -181,13 +185,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ # TODO(b/327796566): re-enable after the bug is fixed @@ -200,15 +204,28 @@ tf_cc_test( ], deps = [ ":quantize_weights", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", 
"//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", "@local_tsl//tsl/platform:logging", ], ) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/core/api", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index e7d5e00b703392..1e7cdcdea07d33 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -37,7 +38,6 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/lib/core/status_test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -192,7 +192,8 @@ void VerifyQuantizationScale( class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() { - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -277,7 +278,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, protected: QuantizeConvModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); // Flatbuffer is missing calibration data -- add dummy params. 
@@ -347,7 +349,7 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() { - input_model_ = ReadModel(internal::kConvModelWithNoBias); + input_model_ = ReadModel(::mlir::lite::internal::kConvModelWithNoBias); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -367,7 +369,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() { - input_model_ = ReadModel(internal::kModelSplit); + input_model_ = ReadModel(::mlir::lite::internal::kModelSplit); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -452,7 +454,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, protected: QuantizeConvModel2Test() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); auto& subgraph = model_.subgraphs[0]; @@ -690,7 +693,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() { - input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5); + input_model_ = + ReadModel(::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -753,7 +757,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public QuantizeModelTest { protected: QuantizeAvgPoolTest() { - input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5); + input_model_ = + 
ReadModel(::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -813,7 +818,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() { - input_model_ = ReadModel(internal::kMultiInputAddWithReshape); + input_model_ = ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -933,7 +938,7 @@ class QuantizeConstInputTest : public QuantizeModelTest, protected: QuantizeConstInputTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConstInputAddModel); + input_model_ = ReadModel(::mlir::lite::internal::kConstInputAddModel); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -980,7 +985,7 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() { - input_model_ = ReadModel(internal::kModelWithArgMaxOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithArgMaxOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1025,7 +1030,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() { - input_model_ = ReadModel(internal::kLstmCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1037,7 +1042,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { /*allow_float=*/true, TensorType_INT8, output_buffer_)); // Read expected model. 
- auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1048,7 +1053,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() { - input_model_ = ReadModel(internal::kLstmCalibrated2); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated2); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1061,7 +1066,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1072,7 +1077,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() { - input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated); + input_model_ = ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1086,7 +1092,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. 
auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1097,7 +1103,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() { - input_model_ = ReadModel(internal::kSvdfCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kSvdfCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1110,7 +1116,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1123,7 +1129,7 @@ class QuantizeFCTest : public QuantizeModelTest, protected: QuantizeFCTest() { disable_per_channel_quantization_for_dense_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithFCOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithFCOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1371,7 +1377,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() { - input_model_ = ReadModel(internal::kModelMixed); + input_model_ = ReadModel(::mlir::lite::internal::kModelMixed); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1409,7 +1415,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizePackTest : public QuantizeModelTest { protected: QuantizePackTest() { - input_model_ 
= ReadModel(internal::kModelPack); + input_model_ = ReadModel(::mlir::lite::internal::kModelPack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1526,14 +1532,15 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { Eq(input2->quantization->zero_point)); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() { - input_model_ = ReadModel(internal::kModelWithUnpack); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithUnpack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1583,7 +1590,7 @@ class QuantizeBroadcastToModelTest protected: QuantizeBroadcastToModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithBroadcastToOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1646,7 +1653,7 @@ class QuantizeGatherNDModelTest protected: QuantizeGatherNDModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithGatherNDOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithGatherNDOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1706,7 +1713,7 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() { - input_model_ = ReadModel(internal::kModelWithWhereOp); + input_model_ = 
ReadModel(::mlir::lite::internal::kModelWithWhereOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 2e80bcae7486b4..7a42e74c2619af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -33,7 +35,6 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/platform/logging.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -59,25 +60,25 @@ std::unique_ptr CreateMutableModelFromFile(const Model* input_model) { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc similarity index 95% rename from tensorflow/lite/tools/optimize/test_util.cc rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc index 5ca45326d1dcad..e096868eec8807 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ 
b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/tools/optimize/test_util.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { const char* kConvModelWithMinus128Plus127Weights = "single_conv_weights_min_minus_127_max_plus_127.bin"; @@ -89,5 +89,5 @@ int FailOnErrorReporter::Report(const char* format, va_list args) { return 0; } } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h similarity index 93% rename from tensorflow/lite/tools/optimize/test_util.h rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.h index 11e7ef230910f2..b4e317c131888e 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ -#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ #include "tensorflow/lite/core/api/error_reporter.h" -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { // Test model with a single convolution. // Floating point weights of the model are all integers and lie in @@ -132,12 +132,12 @@ extern const char* kQatModelWithFc; extern const char* kModelWithResourceVarsCalibrated; // An error reporter that fails on testing. -class FailOnErrorReporter : public ErrorReporter { +class FailOnErrorReporter : public tflite::ErrorReporter { public: int Report(const char* format, va_list args) override; }; } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir -#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ diff --git a/tensorflow/lite/tools/optimize/testdata/README.md b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/README.md rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md diff --git a/tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin diff --git a/tensorflow/lite/tools/optimize/testdata/argmax.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin similarity index 100% rename from 
tensorflow/lite/tools/optimize/testdata/argmax.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin diff --git a/tensorflow/lite/tools/optimize/testdata/broadcast_to.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/broadcast_to.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin diff --git a/tensorflow/lite/tools/optimize/testdata/concat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/concat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/custom_op.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/custom_op.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc_qat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc_qat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/gather_nd.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/gather_nd.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin diff --git 
a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/maximum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/maximum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/minimum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/minimum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed.bin 
b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed16x8.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed16x8.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin diff --git a/tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin diff --git a/tensorflow/lite/tools/optimize/testdata/pack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/pack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin diff --git a/tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin diff --git 
a/tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin rename to 
tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/split.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/transpose.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/transpose.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin 
similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unpack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin diff --git a/tensorflow/lite/tools/optimize/testdata/where.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/where.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 711f97bdfddd16..a05a5cbdb10710 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -14,10 +14,6 @@ package( licenses = ["notice"], ) -exports_files(glob([ - "testdata/*.bin", -])) - cc_library( name = "reduced_precision_support", srcs = [], @@ -39,7 +35,6 @@ tf_cc_test( ], deps = [ ":reduced_precision_support", - ":test_util", "//tensorflow/core/platform:platform_port", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", @@ -223,7 +218,6 @@ tf_cc_test( ], deps = [ ":model_utils", - ":test_util", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", "//tensorflow/lite/schema:schema_fbs", @@ -250,10 +244,10 @@ tf_cc_test( name = 
"quantization_utils_test", srcs = ["quantization_utils_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - ":testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", ], tags = [ "tflite_not_portable_android", @@ -261,7 +255,7 @@ tf_cc_test( ], deps = [ ":quantization_utils", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -316,13 +310,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ "tflite_not_portable_android", @@ -330,7 +324,7 @@ tf_cc_test( ], deps = [ ":quantize_weights", - ":test_util", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -342,19 +336,6 @@ tf_cc_test( ], ) -cc_library( - name = "test_util", - testonly = 1, - srcs = ["test_util.cc"], - hdrs = ["test_util.h"], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/core/api", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - cc_library( name = "quantize_model", srcs = ["quantize_model.cc"], @@ -379,40 +360,40 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - 
"//tensorflow/lite/tools/optimize:testdata/resource_vars_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + 
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/resource_vars_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -420,7 +401,7 @@ tf_cc_test( ], deps = [ ":quantize_model", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", 
"//tensorflow/core:lib", "//tensorflow/lite:framework", diff --git a/tensorflow/lite/tools/optimize/model_utils_test.cc b/tensorflow/lite/tools/optimize/model_utils_test.cc index 65e3afe35e2da2..f702e1fa0a0ddd 100644 --- a/tensorflow/lite/tools/optimize/model_utils_test.cc +++ b/tensorflow/lite/tools/optimize/model_utils_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index a09acef6f4aa3c..a0ab9c43eacb75 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -46,7 +46,7 @@ std::unique_ptr ReadModel(const char* model) { } std::unique_ptr ReadConvModel() { - return ReadModel(internal::kConvModelWith0Plus10Weights); + return ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights); } using ::testing::ElementsAreArray; diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 681507c8e0d31d..a7e9115f8bdaaa 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc @@ -78,7 +78,8 @@ TensorType GetBiasTensorType(TensorType& activation_type) { class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) {} + : QuantizeModelTest( + ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights)) {} explicit QuantizeModelTest(std::unique_ptr input_model) { input_model_ = std::move(input_model); @@ -91,7 +92,7 @@ class QuantizeModelTest : public testing::Test { const Model* readonly_model_; tflite::ModelT model_; flatbuffers::FlatBufferBuilder builder_; - internal::FailOnErrorReporter error_reporter_; + ::mlir::lite::internal::FailOnErrorReporter error_reporter_; }; void ExpectSameModels(const ModelT& model, const ModelT& expected_model) { @@ -136,7 +137,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} TensorType tensor_type_; @@ -405,7 +407,8 @@ TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWithNoBias)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWithNoBias)) {} }; TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { @@ -422,7 +425,8 @@ class QuantizeConcatModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: 
QuantizeConcatModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) {} void SetUp() override { tensor_type_ = GetParam(); @@ -536,7 +540,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest, class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() - : QuantizeModelTest(ReadModel(internal::kModelSplit)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelSplit)) {} }; // There are two outputs for split with different scales, the resulting model @@ -601,8 +605,8 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) { class QuantizeConvModel1Test : public QuantizeModelTest { protected: QuantizeConvModel1Test() - : QuantizeModelTest( - ReadModel(internal::kConvModelWithMinus128Plus127Weights)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kConvModelWithMinus128Plus127Weights)) {} }; TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { @@ -703,7 +707,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModel2Test() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -925,8 +930,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() - : QuantizeModelTest( - ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { @@ -985,8 +990,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public 
QuantizeModelTest { protected: QuantizeAvgPoolTest() - : QuantizeModelTest( - ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { @@ -1045,7 +1050,8 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() - : QuantizeModelTest(ReadModel(internal::kMultiInputAddWithReshape)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape)) {} }; TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { @@ -1155,7 +1161,8 @@ class QuantizeConstInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConstInputTest() - : QuantizeModelTest(ReadModel(internal::kConstInputAddModel)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConstInputAddModel)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1213,7 +1220,8 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() - : QuantizeModelTest(ReadModel(internal::kModelWithArgMaxOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithArgMaxOp)) {} }; TEST_F(QuantizeArgMaxTest, VerifyArgMax) { @@ -1254,7 +1262,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated)) {} }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { @@ -1265,7 +1273,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. 
- auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1276,7 +1284,8 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated2)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated2)) { + } }; TEST_F(QuantizeLSTM2Test, VerifyLSTM) { @@ -1287,7 +1296,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1298,8 +1307,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() - : QuantizeModelTest( - ReadModel(internal::kUnidirectionalSequenceLstmCalibrated)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated)) {} }; TEST_F(QuantizeUnidirectionalSequenceLSTMTest, @@ -1312,7 +1321,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. 
auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1323,7 +1332,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() - : QuantizeModelTest(ReadModel(internal::kSvdfCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kSvdfCalibrated)) {} }; TEST_F(QuantizeSVDFTest, VerifySVDF) { @@ -1334,7 +1343,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1379,7 +1388,8 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { class QuantizeFCTest : public QuantizeModelTest { protected: - QuantizeFCTest() : QuantizeModelTest(ReadModel(internal::kModelWithFCOp)) {} + QuantizeFCTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithFCOp)) {} }; TEST_F(QuantizeFCTest, VerifyFC) { @@ -1430,7 +1440,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() - : QuantizeModelTest(ReadModel(internal::kModelMixed)), + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1471,7 +1481,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizeOp16x8Test : public QuantizeModelTest { protected: QuantizeOp16x8Test() - : QuantizeModelTest(ReadModel(internal::kModelMixed16x8)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed16x8)) 
{} }; TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { @@ -1502,7 +1512,8 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { class QuantizePackTest : public QuantizeModelTest { protected: - QuantizePackTest() : QuantizeModelTest(ReadModel(internal::kModelPack)) {} + QuantizePackTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelPack)) {} }; TEST_F(QuantizePackTest, VerifyPack) { @@ -1628,14 +1639,16 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { EXPECT_EQ(subgraph->tensors[5]->name, "output"); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() - : QuantizeModelTest(ReadModel(internal::kModelWithUnpack)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithUnpack)) { + } }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { auto status = QuantizeModel(&builder_, &model_, &error_reporter_); @@ -1680,7 +1693,8 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) { class QuantizeTransposeTest : public QuantizeModelTest { protected: QuantizeTransposeTest() - : QuantizeModelTest(ReadModel(internal::kModelWithTranspose)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithTranspose)) {} }; TEST_F(QuantizeTransposeTest, VerifyTranspose) { @@ -1720,7 +1734,8 @@ TEST_F(QuantizeTransposeTest, VerifyTranspose) { class QuantizeQatTest : public QuantizeModelTest { protected: - QuantizeQatTest() : QuantizeModelTest(ReadModel(internal::kQatModelWithFc)) {} + QuantizeQatTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kQatModelWithFc)) {} }; TEST_F(QuantizeQatTest, VerifySingleQuantize) { @@ -1777,7 
+1792,8 @@ class QuantizeBroadcastToModelTest public testing::WithParamInterface { protected: QuantizeBroadcastToModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithBroadcastToOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1844,7 +1860,8 @@ class QuantizeGatherNDModelTest public testing::WithParamInterface { protected: QuantizeGatherNDModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithGatherNDOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithGatherNDOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1906,7 +1923,8 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithWhereOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithWhereOp)) {} }; TEST_F(QuantizeWhereModelTest, QuantizeWhere) { @@ -1976,8 +1994,8 @@ class QuantizeResourcesModelTest public testing::WithParamInterface { protected: QuantizeResourcesModelTest() - : QuantizeModelTest( - ReadModel(internal::kModelWithResourceVarsCalibrated)) { + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kModelWithResourceVarsCalibrated)) { TestType obj = GetParam(); tensor_type_ = obj.tensor_type; modify_range_ = obj.modify_range; @@ -2119,7 +2137,8 @@ class QuantizeConcatConstModelTest public testing::WithParamInterface { protected: QuantizeConcatConstModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) { // Make one of the values constant. 
MakeInputConstant(&model_); } @@ -2224,7 +2243,8 @@ class BiasInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: BiasInputTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)) { BiasTestType obj = GetParam(); tensor_type_ = obj.tensor_type; bias_type_ = obj.bias_type; diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index 0e9c3efc17acd9..b2279ed34908f6 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -40,25 +40,25 @@ namespace { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } 
std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc index 6b5cf538b50c43..19400079b17e96 100644 --- a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc +++ b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { From 5ca783cfc05c02fa363773be9f9ac900e1d5dc27 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jun 2024 09:31:03 -0700 Subject: [PATCH 4/4] [xla:cpu] Move BufferAllocations implementation to header file Resolving buffer slice device memory is on a critical path of every thunk. Move implementation to header and force inlining to improve performance of ultra small kernels. 
PiperOrigin-RevId: 646506658 --- .../select_and_scatter_benchmark_test.cc | 3 +- third_party/xla/xla/service/cpu/runtime/BUILD | 19 ++++- .../service/cpu/runtime/buffer_allocations.cc | 76 ----------------- .../service/cpu/runtime/buffer_allocations.h | 85 ++++++++++++++++--- .../cpu/runtime/buffer_allocations_test.cc | 53 ++++++++++++ .../xla/service/cpu/runtime/kernel_thunk.cc | 30 ++++--- .../xla/stream_executor/host/host_kernel.cc | 16 ++-- .../xla/stream_executor/host/host_kernel.h | 2 + 8 files changed, 172 insertions(+), 112 deletions(-) delete mode 100644 third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc create mode 100644 third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index bbc32250444b0f..b03df23e41e764 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -76,7 +76,6 @@ BENCHMARK(BM_SelectAndScatterF32) ->Arg(64) ->Arg(128) ->Arg(256) - ->Arg(512) - ->Arg(1024); + ->Arg(512); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 8c37031edf7985..f9ad7ef5c51300 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -17,21 +17,36 @@ package_group( cc_library( name = "buffer_allocations", - srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], deps = [ + "//xla:util", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", 
"@local_tsl//tsl/platform:statusor", ], ) +xla_cc_test( + name = "buffer_allocations_test", + srcs = ["buffer_allocations_test.cc"], + deps = [ + ":buffer_allocations", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "task", hdrs = ["task.h"], diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc deleted file mode 100644 index e35b931c08e5bc..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/cpu/runtime/buffer_allocations.h" - -#include - -#include "absl/base/optimization.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/service/buffer_assignment.h" -#include "xla/stream_executor/device_memory.h" -#include "tsl/platform/statusor.h" - -namespace xla::cpu { - -absl::StatusOr BufferAllocations::GetDeviceAddress( - BufferAllocation::Index buffer_index) const { - if (ABSL_PREDICT_FALSE(buffer_index < 0 || buffer_index >= buffers_.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Invalid buffer_index ", buffer_index, - " value. It must be in the range [0, ", buffers_.size(), ")")); - } - - return buffers_[buffer_index].AsDeviceMemoryBase(); -} - -absl::StatusOr BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice) const { - // Handle empty slices explicitly and return a null pointer device memory to - // guarantee that we do not accidentally write through the empty slice which - // would hide a real bug in the code. 
- if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) { - return se::DeviceMemoryBase(nullptr, 0); - } - - int64_t index = buffer_slice.index(); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase base, GetDeviceAddress(index)); - - int64_t offset = buffer_slice.offset(); - int64_t extent = offset + buffer_slice.size(); - - if (ABSL_PREDICT_FALSE(offset < 0)) { - return absl::InvalidArgumentError( - absl::StrCat("Buffer slice offset ", offset, " must be non-negative")); - } - - if (ABSL_PREDICT_FALSE(offset >= base.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Buffer slice offset ", offset, " is out of range for buffer #", index, - " of size ", base.size())); - } - - if (ABSL_PREDICT_FALSE(extent > base.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Buffer slice extent ", extent, " is out of range for buffer #", index, - " of size ", base.size())); - } - - return base.GetByteSlice(offset, buffer_slice.size()); -} - -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h index 76f05390a01b07..7abcff73fb5b66 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h +++ b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h @@ -16,39 +16,102 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ #define XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" namespace xla::cpu { // Buffer allocation is a container for device buffers allocated for a // particular XLA execution. Buffers are indexed by the buffer allocation index. 
-// -// TODO(b/342513610): BufferAllocations should be unified with a same class in -// the XLA:GPU runtime, probably as a part of `buffer_assignment.h`. class BufferAllocations { public: - explicit BufferAllocations(absl::Span buffers) - : buffers_(buffers) {} + explicit inline BufferAllocations( + absl::Span buffers); // Returns the device address of buffer `buffer_index`. `buffer_index` must be // a valid index, i.e., in [0, buffer_count). - absl::StatusOr GetDeviceAddress( - BufferAllocation::Index buffer_index) const; + inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr + GetDeviceAddress(BufferAllocation::Index buffer_index) const; // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. - absl::StatusOr GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice) const; + inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr + GetDeviceAddress(const BufferAllocation::Slice& buffer_slice) const; private: - // TODO(ezhulenev): Make BufferAllocations an owner of the buffers. - absl::Span buffers_; // not owned + std::vector buffers_; + size_t num_buffers_; }; +BufferAllocations::BufferAllocations( + absl::Span buffers) + : buffers_(buffers.size()), num_buffers_(buffers_.size()) { + for (size_t i = 0; i < buffers.size(); ++i) { + buffers_[i] = buffers[i].AsDeviceMemoryBase(); + } +} + +absl::StatusOr BufferAllocations::GetDeviceAddress( + BufferAllocation::Index index) const { + if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) { + return InvalidArgument( + "Invalid buffer index %d. It must be in the range [0, %d)", index, + num_buffers_); + } + + return buffers_[index]; +} + +absl::StatusOr BufferAllocations::GetDeviceAddress( + const BufferAllocation::Slice& buffer_slice) const { + // Handle empty slices explicitly and return a null pointer device memory to + // guarantee that we do not accidentally write through the empty slice which + // would hide a real bug in the code. 
+ if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) { + return se::DeviceMemoryBase(nullptr, 0); + } + + int64_t index = buffer_slice.index(); + if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) { + return InvalidArgument( + "Invalid buffer index %d. It must be in the range [0, %d)", index, + num_buffers_); + } + const se::DeviceMemoryBase& base = buffers_[index]; + + int64_t offset = buffer_slice.offset(); + int64_t extent = offset + buffer_slice.size(); + + if (ABSL_PREDICT_FALSE(offset < 0)) { + return InvalidArgument("Buffer slice offset %d must be non-negative", + offset); + } + + if (ABSL_PREDICT_FALSE(offset >= base.size())) { + return InvalidArgument( + "Buffer slice offset %d is out of range for buffer #%d of size %d", + offset, index, base.size()); + } + + if (ABSL_PREDICT_FALSE(extent > base.size())) { + return InvalidArgument( + "Buffer slice extent %d is out of range for buffer #%d of size %d", + extent, index, base.size()); + } + + return base.GetByteSlice(offset, buffer_slice.size()); +} + } // namespace xla::cpu #endif // XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc b/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc new file mode 100644 index 00000000000000..f281924e2542ac --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/service/cpu/runtime/buffer_allocations.h" + +#include +#include + +#include "xla/service/buffer_assignment.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(BufferAllocationsTest, GetDeviceAddress) { + std::vector buffers; + std::vector data = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = data.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation alloc(0, size_in_bytes, 0); + BufferAllocation::Slice slice(&alloc, /*offset=*/2 * sizeof(float), + /*size=*/sizeof(float)); + + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase alloc_mem, + allocations.GetDeviceAddress(0)); + EXPECT_EQ(alloc_mem.opaque(), &data[0]); + + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase slice_mem, + allocations.GetDeviceAddress(slice)); + EXPECT_EQ(slice_mem.opaque(), &data[2]); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index a8d793d1076071..21c57fef35f940 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,40 +87,46 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector kernel_args; - kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); + int64_t num_args = arguments_buffers_.size() + results_buffers_.size(); + absl::InlinedVector kernel_args(num_args); + + // We initialize `kernel_args` array using pointer to the first argument, + // because individual elements access adds up measurable 
overhead, and this + // code is on the critical path. + SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); + int64_t kernel_arg_idx = 0; int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); - kernel_args.push_back( - SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), kernel_args.back().data); + buffer.ToString(), arg_data.opaque()); + kernel_args_ptr[kernel_arg_idx++] = + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}; } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); - kernel_args.push_back( - SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), kernel_args.back().data); + buffer.ToString(), result_data.opaque()); + kernel_args_ptr[kernel_arg_idx++] = + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}; } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. 
if (min_alignment_.has_value()) { - for (int64_t i = 0; i < kernel_args.size(); ++i) { - auto ptr = reinterpret_cast(kernel_args[i].data); + for (int64_t i = 0; i < num_args; ++i) { + auto ptr = reinterpret_cast(kernel_args_ptr[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, kernel_args[i].data, *min_alignment_); + info().op_name, i, kernel_args_ptr[i].data, *min_alignment_); } } } @@ -136,7 +142,7 @@ tsl::AsyncValueRef KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); + kernel_.emplace(num_args, kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.cc b/third_party/xla/xla/stream_executor/host/host_kernel.cc index 04586b5272432b..cad37e1bfa4fb0 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel.cc @@ -67,8 +67,7 @@ class HostKernelExecuteState : public tsl::ReferenceCounted { public: HostKernelExecuteState(HostKernel::TaskRunner task_runner, - HostKernel::KernelFunction* function, - ThreadDim thread_dims, + SE_HOST_Kernel* kernel, ThreadDim thread_dims, absl::Span args); // Notify of a completion of a host kernel task. 
@@ -112,6 +111,7 @@ HostKernel::HostKernel(std::shared_ptr thread_pool) HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, std::shared_ptr thread_pool) : function_(std::make_unique(kernel)), + kernel_(function_->kernel()), arity_(arity), thread_pool_(thread_pool) {} @@ -130,8 +130,6 @@ absl::Status HostKernel::Launch( thread_dims.z, }; - SE_HOST_Kernel* kernel = function_->kernel(); - for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { for (uint64_t x = 0; x < thread_dims.x; ++x) { @@ -140,7 +138,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelCallFrame call_frame = { &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; - SE_HOST_KernelError* error = (*kernel)(&call_frame); + SE_HOST_KernelError* error = (*kernel_)(&call_frame); if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); @@ -174,8 +172,8 @@ tsl::AsyncValueRef HostKernel::Launch( } // Allocate a control structure that will orchestrate kernel execution. 
- auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, args); + auto state = tsl::MakeRef(std::move(task_runner), + kernel_, thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -183,11 +181,11 @@ tsl::AsyncValueRef HostKernel::Launch( } HostKernelExecuteState::HostKernelExecuteState( - HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, + HostKernel::TaskRunner task_runner, SE_HOST_Kernel kernel, ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), - kernel_(function->kernel()), + kernel_(kernel), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), args_(args.begin(), args.end()), abort_(false), diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.h b/third_party/xla/xla/stream_executor/host/host_kernel.h index 9d278b2b79c357..9bc96cb9e7ca2a 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel.h @@ -113,10 +113,12 @@ class HostKernel : public Kernel { std::enable_if_t>* = nullptr> void SetKernelFunction(std::unique_ptr function) { function_ = std::move(function); + kernel_ = function_->kernel(); } private: std::unique_ptr function_; + SE_HOST_Kernel* kernel_; // pointer to the kernel owned by `function_` unsigned arity_; std::shared_ptr thread_pool_;