Provide a runtime option to lower bound the number of batch threads.
PiperOrigin-RevId: 626118642
deqiangc authored and tensorflower-gardener committed Apr 18, 2024
1 parent 762375d commit 45f30ac
Showing 9 changed files with 171 additions and 2 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/tfrt/BUILD
@@ -206,6 +206,7 @@ cc_library(
"transforms/deduplicate_if_result_pass.cc",
"transforms/fuse_tpu_compile_and_execute_ops.cc",
"transforms/insert_tensor_copy.cc",
"transforms/lower_bound_batch_threads.cc",
"transforms/lower_saved_model.cc",
"transforms/merge_tf_if_ops.cc",
"transforms/optimize.cc",
53 changes: 53 additions & 0 deletions tensorflow/compiler/mlir/tfrt/tests/lower_bound_batch_threads.mlir
@@ -0,0 +1,53 @@
// RUN: tf-tfrt-opt -split-input-file -tfrt-lower-bound-batch-threads="tfrt-min-num-batch-threads=2" %s | FileCheck %s --dump-input=always

// -----

// The num_batch_threads attribute is raised to the lower bound of 2 from its original value of 1.

// CHECK-LABEL: func private @batched_function
func.func private @batched_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> {
%2 = "tf.Identity"(%arg0) : (tensor<1x3xf32>) -> tensor<1x3xf32>
func.return %2 : tensor<1x3xf32>
}

// CHECK-LABEL: func @main
func.func @main(%arg0: tensor<1x3xf32>) -> tensor<*xf32> {
// CHECK: "tf.BatchFunction"
// CHECK-SAME: allowed_batch_sizes = [6]
// CHECK-SAME: batch_timeout_micros = 100000 : i64
// CHECK-SAME: batching_queue = ""
// CHECK-SAME: container = ""
// CHECK-SAME: enable_large_batch_splitting = false
// CHECK-SAME: max_batch_size = 6 : i64
// CHECK-SAME: max_enqueued_batches = 10 : i64
// CHECK-SAME: num_batch_threads = 2 : i64
// CHECK-SAME: shared_name = "batch/"
%1 = "tf.BatchFunction"(%arg0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operandSegmentSizes = array<i32: 1, 0>, shared_name = "batch/"} : (tensor<1x3xf32>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
}

// -----

// The num_batch_threads attribute remains 3, since it already exceeds the lower bound of 2.

// CHECK-LABEL: func private @batched_function
func.func private @batched_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> {
%2 = "tf.Identity"(%arg0) : (tensor<1x3xf32>) -> tensor<1x3xf32>
func.return %2 : tensor<1x3xf32>
}

// CHECK-LABEL: func @main
func.func @main(%arg0: tensor<1x3xf32>) -> tensor<*xf32> {
// CHECK: "tf.BatchFunction"
// CHECK-SAME: allowed_batch_sizes = [6]
// CHECK-SAME: batch_timeout_micros = 100000 : i64
// CHECK-SAME: batching_queue = ""
// CHECK-SAME: container = ""
// CHECK-SAME: enable_large_batch_splitting = false
// CHECK-SAME: max_batch_size = 6 : i64
// CHECK-SAME: max_enqueued_batches = 10 : i64
// CHECK-SAME: num_batch_threads = 3 : i64
// CHECK-SAME: shared_name = "batch/"
%1 = "tf.BatchFunction"(%arg0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 3 : i64, operandSegmentSizes = array<i32: 1, 0>, shared_name = "batch/"} : (tensor<1x3xf32>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
}
93 changes: 93 additions & 0 deletions tensorflow/compiler/mlir/tfrt/transforms/lower_bound_batch_threads.cc
@@ -0,0 +1,93 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <memory>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

namespace tensorflow {
namespace tfrt_compiler {
namespace {

class LowerBoundBatchThreadsPass
: public mlir::PassWrapper<LowerBoundBatchThreadsPass,
mlir::OperationPass<mlir::ModuleOp>> {
public:
explicit LowerBoundBatchThreadsPass(int64_t min_num_batch_threads)
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
mlir::OperationPass<mlir::ModuleOp>>() {
min_num_batch_threads_ = min_num_batch_threads;
}
LowerBoundBatchThreadsPass()
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
mlir::OperationPass<mlir::ModuleOp>>() {}
LowerBoundBatchThreadsPass(const LowerBoundBatchThreadsPass& other)
: mlir::PassWrapper<LowerBoundBatchThreadsPass,
mlir::OperationPass<mlir::ModuleOp>>(other) {}

LowerBoundBatchThreadsPass& operator=(
const LowerBoundBatchThreadsPass& other) = delete;

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBoundBatchThreadsPass)

private:
llvm::StringRef getArgument() const final {
return "tfrt-lower-bound-batch-threads";
}

llvm::StringRef getDescription() const final {
return "Lower bound batch threads for batch ops.";
}

void runOnOperation() override {
if (min_num_batch_threads_ > 0) {
mlir::ModuleOp module = getOperation();
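// Raise num_batch_threads on each tf.BatchFunction to at least the
// configured minimum; ops that already request more threads keep their value.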
module.walk([&](mlir::TF::BatchFunctionOp batch_op) {
int64_t num_batch_threads = batch_op.getNumBatchThreads();
num_batch_threads =
std::max(num_batch_threads, min_num_batch_threads_.getValue());
batch_op.setNumBatchThreads(num_batch_threads);
});
}
}

protected:
mlir::Pass::Option<int64_t> min_num_batch_threads_{
*this, "tfrt-min-num-batch-threads", llvm::cl::init(1),
llvm::cl::desc("Minimum number of batch threads")};
};

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads) {
return std::make_unique<LowerBoundBatchThreadsPass>(min_num_batch_threads);
}

static mlir::PassRegistration<LowerBoundBatchThreadsPass> register_pass;

} // namespace tfrt_compiler
} // namespace tensorflow
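
For orientation, a minimal sketch of driving the new pass programmatically; the `mlir::PassManager` plumbing is standard MLIR boilerplate, the wrapper function `ApplyBatchThreadFloor` is a hypothetical name, and only `CreateLowerBoundBatchThreadsPass` comes from this change:

#include <cstdint>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

// Raises num_batch_threads to at least `floor` on every tf.BatchFunction in
// `module`; ops already at or above the floor are left unchanged.
mlir::LogicalResult ApplyBatchThreadFloor(mlir::ModuleOp module,
                                          int64_t floor) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(
      tensorflow::tfrt_compiler::CreateLowerBoundBatchThreadsPass(floor));
  return pm.run(module);
}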
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/tfrt/transforms/passes.cc
@@ -117,6 +117,10 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper(
// Merge non-side-effecting tf.If ops if their operands are the same.
pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());

// Enforce a lower bound on the number of batch threads in `tf.BatchFunction` ops.
pm.addPass(tfrt_compiler::CreateLowerBoundBatchThreadsPass(
options.min_num_batch_threads));

// Deduplicate functions invoked by tf.BatchFunction with the same
// shared_name.
pm.addPass(
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/tfrt/transforms/passes.h
@@ -66,6 +66,10 @@ std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMergeTfIfOpsPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass();

// Create a pass to lower bound the number of batch threads in `tf.BatchFunction`.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerBoundBatchThreadsPass(int64_t min_num_batch_threads);

// Create a pass to fuse the TPU Ops for TFRT.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseTpuCompileAndExecutePass();
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_

#include <cstdint>
#include <string>

#include "llvm/Support/CommandLine.h"
@@ -144,6 +145,10 @@ struct TfrtPipelineOptions
"cheap, and then whether it can be executed inline."),
llvm::cl::init(1)};

Option<int64_t> min_num_batch_threads{
*this, "tfrt-min-num-batch-threads",
llvm::cl::desc("The minimum number of batch threads"), llvm::cl::init(1)};

Option<bool> merge_inter_dependent_streams{
*this, "tfrt-merge-inter-dependent-streams",
llvm::cl::desc("If true, streams with inter data depenedencies will be "
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tfrt/translate/import_model.cc
@@ -346,6 +346,8 @@ std::unique_ptr<tensorflow::TfrtPipelineOptions> GetTfrtPipelineOptions(
pipeline_options->enable_while_parallel_iterations =
options.enable_while_parallel_iterations;
pipeline_options->cost_threshold = options.cost_threshold;
pipeline_options->min_num_batch_threads = options.min_num_batch_threads;

pipeline_options->merge_inter_dependent_streams =
options.merge_inter_dependent_streams;

4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc
@@ -40,8 +40,7 @@ std::ostream& operator<<(std::ostream& os,
}

std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) {
return os << "{"
<< "variable_device = " << options.variable_device
return os << "{" << "variable_device = " << options.variable_device
<< ", default_device = " << options.default_device
<< ", enable_optimizer = " << options.enable_optimizer
<< ", enable_grappler = " << options.enable_grappler
@@ -58,6 +57,7 @@ std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) {
<< ", enable_while_parallel_iterations = "
<< options.enable_while_parallel_iterations
<< ", cost_threshold = " << options.cost_threshold
<< ", min_num_batch_threads = " << options.min_num_batch_threads
<< ", merge_inter_dependent_streams = "
<< options.merge_inter_dependent_streams
<< ", decompose_resource_ops = " << options.decompose_resource_ops
7 changes: 7 additions & 0 deletions tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h
@@ -136,6 +136,13 @@ struct TfrtCompileOptions {
// expensive.
uint64_t cost_threshold = 1;

// The minimum number of batch threads. This number provides a lower bound on
// the number of batch threads on top of what is specified in the model. If
// the number of batch threads is too small (e.g. smaller than the number of
// parallel hardware accelerators available), it can lead to underutilization
// of resources.
int64_t min_num_batch_threads = 1;

// If true, streams with inter data dependencies will be preferred to be
// merged for inline execution.
bool merge_inter_dependent_streams = true;
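
And a usage sketch for the serving-side knob (a hypothetical call site; the header path is assumed, and only the `min_num_batch_threads` field and its propagation through `GetTfrtPipelineOptions` are from this change):

#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"

tensorflow::TfrtCompileOptions MakeCompileOptions() {
  tensorflow::TfrtCompileOptions options;
  // Guarantee at least 8 batch threads at runtime; models whose
  // tf.BatchFunction ops already request more keep their larger value.
  options.min_num_batch_threads = 8;
  return options;
}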
