diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 86c578a33b5600..906bc745c7738d 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -32,6 +32,30 @@ cc_library( ], ) +cc_library( + name = "batch_stats", + hdrs = ["batch_stats.h"], + deps = [ + "//tensorflow/core:framework_lite", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +tf_cc_test( + name = "batch_stats_test", + srcs = ["batch_stats_test.cc"], + deps = [ + ":batch_stats", + "//tensorflow/core:test", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "batch_input_task", hdrs = ["batch_input_task.h"], @@ -100,6 +124,7 @@ cc_library( "//tensorflow/core/lib/core:status", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:criticality", @@ -113,6 +138,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -126,7 +152,12 @@ cc_library( srcs = ["batch_scheduler_utils.cc"], hdrs = ["batch_scheduler_utils.h"], deps = [ + ":batch_scheduler_hdrs", "//tensorflow/core:portable_gif_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", ], ) @@ -160,7 +191,9 @@ tf_cc_test( name = "batch_scheduler_utils_test", srcs = ["batch_scheduler_utils_test.cc"], deps = [ + ":batch_scheduler_hdrs", ":batch_scheduler_utils", + "@com_google_absl//absl/flags:flag", "@com_google_googletest//:gtest_main", ], ) @@ -402,6 +435,7 @@ cc_library( ":adaptive_shared_batch_scheduler", ":batch_scheduler", ":batch_scheduler_utils", + ":batch_stats", ":concat_split_util", ":input_split_metadata", ":shared_batch_scheduler", @@ -457,12 +491,15 @@ tf_cc_test( srcs = ["batch_resource_base_test.cc"], deps = [ ":batch_resource_base", + ":batch_stats", "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:cost_constants", "//tensorflow/core/common_runtime:cost_measurement", "//tensorflow/core/common_runtime:cost_measurement_registry", "//tensorflow/core/common_runtime:no_op_cost_measurement", "//tensorflow/core/common_runtime:request_cost", "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index fde69d27b0c6bd..8807889e911979 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -53,6 +53,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" #include "tensorflow/core/kernels/batching_util/concat_split_util.h" #include "tensorflow/core/kernels/batching_util/input_split_metadata.h" #include "tensorflow/core/kernels/batching_util/threadsafe_status.h" @@ -1235,6 +1236,15 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics( absl::StrCat(cost_type, kNoSmearSuffix), total_cost / processed_size * batch.size()); + if (cost_type == kTpuCostName) { + // Register TPU cost for in-process use. + GlobalBatchStats() + .model(model_name) + .batch_size(processed_size) + .tpu_cost() + .Register(total_cost); + } + for (int i = 0; i < batch.num_tasks(); i++) { RequestCost* request_cost = batch.task(i).request_cost; // Skip recording the cost if the request_cost is null. diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index 7aecac23691a4d..1c5b4059d0e566 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -23,12 +23,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "tensorflow/core/common_runtime/cost_constants.h" #include "tensorflow/core/common_runtime/cost_measurement.h" #include "tensorflow/core/common_runtime/cost_measurement_registry.h" #include "tensorflow/core/common_runtime/request_cost.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" #include "tsl/platform/criticality.h" namespace tensorflow { @@ -294,6 +296,39 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) { UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)))))); } +TEST(SplitBatchCostsAndRecordMetricsTest, UpdatesGlobalBatchStats) { + // Create batch_cost_measurements with one TPU cost. + class FakeTpuCostMeasurement : public CostMeasurement { + public: + using CostMeasurement::CostMeasurement; + absl::Duration GetTotalCost() override { return absl::Hours(555); } + absl::string_view GetCostType() const override { return kTpuCostName; } + }; + CostMeasurement::Context context{/* is_per_query= */ false}; + std::vector> batch_cost_measurements; + batch_cost_measurements.push_back( + std::make_unique(context)); + + // Create a non-empty batch. + BatchResourceBase::BatchT batch; + batch.AddTask(MakeBatchTask(/* task_size= */ 1, nullptr)); + batch.Close(); + + // Pick a model name that no other test would pick. This is so that we are + // sure that the CPU cost for this model name has either never been reported + // before or, if this test is executed multiple times, has been reported by + // this only. 
+ const char kModelName[] = __FILE__; + + BatchResourceBase::SplitBatchCostsAndRecordMetrics( + kModelName, batch_cost_measurements, + /* processed_size= */ 17, batch); + + EXPECT_EQ( + GlobalBatchStats().model(kModelName).batch_size(17).tpu_cost().mean(), + absl::Hours(555)); +} + } // namespace } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h index c70972cc2bf6b4..1d64defcefe6c6 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h @@ -32,18 +32,18 @@ limitations under the License. #include #include #include -#include +#include #include #include #include #include +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" @@ -304,6 +304,15 @@ class Batch { // Returns the TraceMe context id of this batch. uint64 traceme_context_id() const; + // Attempts to trim this batch to a new, smaller size (not to be confused with + // the number of tasks in the batch). On success, the trimmed tasks go into + // 'out_trimmed_tasks' in the same order the tasks were in this batch. + // + // The method might not succeed if it needs to split a large task to hit the + // correct size. + void TryTrimToNewSize( + int new_size, std::vector>& out_trimmed_tasks); + private: mutable mutex mu_; @@ -505,6 +514,42 @@ uint64 Batch::traceme_context_id() const { return traceme_context_id_; } +template +void Batch::TryTrimToNewSize( + int new_size, std::vector>& out_trimmed_tasks) { + mutex_lock l(mu_); + DCHECK_GT(new_size, 0); + DCHECK_LT(new_size, size_); + DCHECK(out_trimmed_tasks.empty()); + + // Index of the first task to trim away. It is possible that it is the index + // of a task of size larger than 1 that will have to be split in order to get + // to the target new_size. + int32 first_task_to_move = 0; + // The sum of sizes of tasks i, where i < first_task_to_move. + int32 size_of_previous_tasks = 0; + while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <= + new_size) { + size_of_previous_tasks += tasks_[first_task_to_move]->size(); + first_task_to_move++; + } + + // Check whether task 'first_task_to_move' will have to be split. + if (size_of_previous_tasks < new_size) { + // TODO: b/325954758 - Consider supporting splitting large tasks and then + // drop 'Try' from the method name. + return; + } + DCHECK_EQ(size_of_previous_tasks, new_size); + + // Actually trim. + out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move); + std::move(tasks_.begin() + first_task_to_move, tasks_.end(), + std::back_inserter(out_trimmed_tasks)); + tasks_.resize(first_task_to_move); + size_ = new_size; +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc index e159c4373fdf90..2f9c9031776373 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc @@ -21,12 +21,13 @@ limitations under the License. 
#include #include #include +#include #include #include "absl/status/status.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/criticality.h" @@ -37,6 +38,7 @@ namespace { using ::testing::ElementsAre; using ::testing::Eq; +using ::testing::Pointer; using ::testing::Property; TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) { @@ -386,6 +388,53 @@ TEST(BatchTest, RemoveAllTasks) { EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call } +TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) { + Batch batch; + + auto task0 = new FakeTask(3); + batch.AddTask(std::unique_ptr(task0)); + + auto task1 = new FakeTask(5); + batch.AddTask(std::unique_ptr(task1)); + + auto task2 = new FakeTask(7); + batch.AddTask(std::unique_ptr(task2)); + + auto task3 = new FakeTask(9); + batch.AddTask(std::unique_ptr(task3)); + + std::vector> trimmed_tasks; + batch.TryTrimToNewSize(/* new_size= */ 8, + /* out_trimmed_tasks= */ trimmed_tasks); + + EXPECT_EQ(batch.size(), 8); + EXPECT_EQ(batch.num_tasks(), 2); + + EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3))); + + batch.Close(); // Batch::~Batch blocks until the batch is closed. +} + +TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) { + Batch batch; + + auto task0 = new FakeTask(3); + batch.AddTask(std::unique_ptr(task0)); + + auto task1 = new FakeTask(5); + batch.AddTask(std::unique_ptr(task1)); + + std::vector> trimmed_tasks; + batch.TryTrimToNewSize(/* new_size= */ 4, + /* out_trimmed_tasks= */ trimmed_tasks); + + EXPECT_EQ(batch.size(), 8); + EXPECT_EQ(batch.num_tasks(), 2); + EXPECT_TRUE(trimmed_tasks.empty()); + + batch.Close(); // Batch::~Batch blocks until the batch is closed. +} + } // namespace } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc index 148bb7c039f2dc..aa97e4fc218439 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc @@ -15,11 +15,30 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include +#include #include +#include "absl/algorithm/container.h" +#include "absl/flags/flag.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +ABSL_FLAG(tensorflow::serving::BatchPaddingPolicy, + tensorflow_batch_padding_policy, + tensorflow::serving::BatchPaddingPolicy::kPadUp, + "The policy that a batch scheduler is using when deciding what to do " + "when, say, 18 requests need to be batched, but only 16 and 32 batch " + "sizes are allowed. The following options are available. PAD_UP: pad " + "to size 32. BATCH_DOWN: schedule a batch of size 16 and leave 2 " + "requests in the batch buffer. MINIMIZE_TPU_COST_PER_REQUEST: a " + "smarter greedy policy that chooses to either PAD_UP or BATCH_DOWN " + "so as to minimize the TPU costs per real request. In this case, it " + "would compare (batch_16_cost / 16) and (batch_32_cost / 18). 
" + "WARNING: not all batch schedulers might support this option."); + namespace tensorflow { namespace serving { @@ -40,5 +59,58 @@ int GetNextAllowedBatchSize(int batch_size, return batch_size; } +int32 GetPrevAllowedBatchSize(int batch_size, + const std::vector& allowed_batch_sizes, + bool disable_padding) { + if (disable_padding || allowed_batch_sizes.empty()) { + return batch_size; + } + + DCHECK(absl::c_is_sorted(allowed_batch_sizes)); + DCHECK_GT(batch_size, 0); + + // First from the end allowed batch size not larger than batch_size. + auto result = std::find_if( + allowed_batch_sizes.rbegin(), allowed_batch_sizes.rend(), + [&](int allowed_size) { return allowed_size <= batch_size; }); + + if (result == allowed_batch_sizes.rend()) { + // No such element exists. + return batch_size; + } + + return *result; +} + +bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out, + std::string* error) { + if (text == "PAD_UP") { + *out = BatchPaddingPolicy::kPadUp; + return true; + } + if (text == "BATCH_DOWN") { + *out = BatchPaddingPolicy::kBatchDown; + return true; + } + if (text == "MINIMIZE_TPU_COST_PER_REQUEST") { + *out = BatchPaddingPolicy::kMinimizeTpuCostPerRequest; + return true; + } + *error = "unrecognized batching policy string"; + return false; +} + +string AbslUnparseFlag(BatchPaddingPolicy in) { + switch (in) { + case BatchPaddingPolicy::kPadUp: + return "PAD_UP"; + case BatchPaddingPolicy::kBatchDown: + return "BATCH_DOWN"; + case BatchPaddingPolicy::kMinimizeTpuCostPerRequest: + return "MINIMIZE_TPU_COST_PER_REQUEST"; + } + CHECK(FATAL) << "Unrecognized BatchPaddingPolicy enum value."; // Crash OK +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h index 38831531abd6e7..f4ba26c1193524 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h @@ -16,10 +16,25 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ +#include +#include #include +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" #include "tensorflow/core/platform/types.h" +namespace tensorflow::serving { +enum class BatchPaddingPolicy; // Forward-declaring for the ABSL_DECLARE_FLAG. +} // namespace tensorflow::serving + +// Exposed for testing only. +ABSL_DECLARE_FLAG(tensorflow::serving::BatchPaddingPolicy, + tensorflow_batch_padding_policy); + namespace tensorflow { namespace serving { @@ -30,6 +45,64 @@ int GetNextAllowedBatchSize(int batch_size, const std::vector& allowed_batch_sizes, bool disable_padding); +// Returns the largest allowed batch size that is smaller than or equal to +// batch_size. Returns batch_size if no such size exists. +int GetPrevAllowedBatchSize(int batch_size, + const std::vector& allowed_batch_sizes, + bool disable_padding); + +// See the description of the --tensorflow_batch_padding_policy flag (in the +// .cc file) for the documentation. 
+enum class BatchPaddingPolicy { + kPadUp, + kBatchDown, + kMinimizeTpuCostPerRequest, +}; +bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out, + std::string* error); +std::string AbslUnparseFlag(BatchPaddingPolicy in); + +// Trims the batch down to the previous allowed batch size when possible and +// when configured by the --tensorflow_batch_padding_policy flag. +// +// When trimming, this function puts the trimmed tasks into the +// out_trimmed_tasks vector in the same order as they were in the batch. +template +void MaybeBatchDown(Batch& batch, + const std::vector& allowed_batch_sizes, + bool disable_padding, + std::vector>& out_trimmed_tasks) { + switch (absl::GetFlag(FLAGS_tensorflow_batch_padding_policy)) { + case BatchPaddingPolicy::kPadUp: + // This is the default behavior of the batch resource when it is given a + // batch size that doesn't match any of the allowed batch sizes. + return; + case BatchPaddingPolicy::kBatchDown: + // Continue with this method. + break; + case BatchPaddingPolicy::kMinimizeTpuCostPerRequest: + LOG(DFATAL) << "BatchPaddingPolicy::kMinimizeTpuCostPerRequest is not " + "yet implemented, falling back on kBatchDown."; + break; + } + + int32 batch_size = batch.size(); + + int32 pad_up_size = + GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding); + if (pad_up_size == batch_size) { + return; // Good, no padding is necessary. + } + + int32 batch_down_size = + GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding); + if (batch_down_size == batch_size) { + return; // Can't batch down (e.g. no smaller batch size available). + } + + batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks); +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc index 9cd6ce1ddcb210..e689566c173400 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc @@ -15,7 +15,14 @@ limitations under the License. 
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include +#include +#include +#include + #include +#include "absl/flags/flag.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" namespace tensorflow { namespace serving { @@ -42,6 +49,174 @@ TEST(GetNextAllowedBatchSizeTest, GreaterThanAllowedBatchSize) { EXPECT_EQ(GetNextAllowedBatchSize(10, {2, 4, 8}, false), 10); } +TEST(GetPrevAllowedBatchSizeTest, PaddingDisallowed) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {2, 4, 8}, true), 3); +} + +TEST(GetPrevAllowedBatchSizeTest, EmptyAllowedBatchSizes) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {}, false), 3); +} + +TEST(GetPrevAllowedBatchSizeTest, PrevAllowedBatchSizeFound) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {1, 2, 4, 8}, false), 2); +} + +TEST(GetPrevAllowedBatchSizeTest, NoSmallerAllowedBatchSizeFound) { + EXPECT_EQ(GetPrevAllowedBatchSize(3, {4, 8}, false), 3); +} + +TEST(GetPrevAllowedBatchSizeTest, AlreadyAllowedBatchSize) { + EXPECT_EQ(GetPrevAllowedBatchSize(2, {1, 2, 4, 8}, false), 2); +} + +TEST(GetPrevAllowedBatchSizeTest, GreaterThanMaxAllowedBatchSize) { + EXPECT_EQ(GetPrevAllowedBatchSize(10, {2, 4, 8}, false), 8); +} + +TEST(BatchPaddingPolicyTest, AbslParseFlag) { + std::string error; + BatchPaddingPolicy policy; + + EXPECT_TRUE(AbslParseFlag("PAD_UP", &policy, &error)); + EXPECT_EQ(policy, BatchPaddingPolicy::kPadUp); + EXPECT_EQ(error, ""); + + EXPECT_TRUE(AbslParseFlag("BATCH_DOWN", &policy, &error)); + EXPECT_EQ(policy, BatchPaddingPolicy::kBatchDown); + EXPECT_EQ(error, ""); + + EXPECT_TRUE(AbslParseFlag("MINIMIZE_TPU_COST_PER_REQUEST", &policy, &error)); + EXPECT_EQ(policy, BatchPaddingPolicy::kMinimizeTpuCostPerRequest); + EXPECT_EQ(error, ""); + + EXPECT_FALSE(AbslParseFlag("cucumber", &policy, &error)); + EXPECT_NE(error, ""); +} + +TEST(BatchPaddingPolicyTest, AbslUnparseFlag) { + EXPECT_EQ(AbslUnparseFlag(BatchPaddingPolicy::kPadUp), "PAD_UP"); + EXPECT_EQ(AbslUnparseFlag(BatchPaddingPolicy::kBatchDown), "BATCH_DOWN"); + EXPECT_EQ(AbslUnparseFlag(BatchPaddingPolicy::kMinimizeTpuCostPerRequest), + "MINIMIZE_TPU_COST_PER_REQUEST"); +} + +class FakeTask : public BatchTask { + public: + explicit FakeTask(size_t size) : size_(size) {} + + size_t size() const override { return size_; } + + private: + const size_t size_; +}; + +TEST(MaybeBatchDownTest, PadUp) { + absl::SetFlag(&FLAGS_tensorflow_batch_padding_policy, + BatchPaddingPolicy::kPadUp); + + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The batch must stay unchanged (for the batch resource to then pad it to the + // next allowed batch size, thus ending up in a pad-up behavior.) + EXPECT_EQ(batch.size(), 3); +} + +TEST(MaybeBatchDownTest, BatchDown) { + absl::SetFlag(&FLAGS_tensorflow_batch_padding_policy, + BatchPaddingPolicy::kBatchDown); + + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The scheduler should trim the batch to a smaller allowed size that requires + // no padding. 
+ EXPECT_EQ(batch.size(), 2); + // The trimmed part. + EXPECT_EQ(out_trimmed_tasks.size(), 1); +} + +TEST(MaybeBatchDownTest, BatchDownDoesNotSplitTasks) { + absl::SetFlag(&FLAGS_tensorflow_batch_padding_policy, + BatchPaddingPolicy::kBatchDown); + + // Add tasks for size 3, but the second task is large and will have to be + // split if doing batch-down. + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(2)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The batch must stay unchanged due to the fact that the current implementation + // doesn't support splitting large tasks. + EXPECT_EQ(batch.size(), 3); +} + +TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenTheBatchSizeIsAlreadyAllowed) { + absl::SetFlag(&FLAGS_tensorflow_batch_padding_policy, + BatchPaddingPolicy::kBatchDown); + + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The batch should stay unchanged because it's already of an allowed size. + EXPECT_EQ(batch.size(), 4); +} + +TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenNoSmallerAllowedSize) { + absl::SetFlag(&FLAGS_tensorflow_batch_padding_policy, + BatchPaddingPolicy::kBatchDown); + + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {4, 8}, + /* disable_padding= */ false, /* out_trimmed_tasks= */ out_trimmed_tasks); + + // Can't batch down because there is no smaller allowed size. + EXPECT_EQ(batch.size(), 3); +} + } // namespace } // namespace serving diff --git a/tensorflow/core/kernels/batching_util/batch_stats.h b/tensorflow/core/kernels/batching_util/batch_stats.h new file mode 100644 index 00000000000000..a54fae8043f0b3 --- /dev/null +++ b/tensorflow/core/kernels/batching_util/batch_stats.h @@ -0,0 +1,172 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The API for reporting and querying batch statistics such as the average batch +// costs for in-process use. +// +// All these statistics can also be retrieved from metrics reported by various +// modules (e.g., batch_resource_base), but that would be slow. This API, on the +// other hand, was designed to be queried on every request. +// +// The classes defined here are not supposed to be instantiated by the user. 
+// Instead, this file provides a single entry point: +// +// BatchStats& GlobalBatchStats(); +// +// For example, to register batch cost, do: +// +// GlobalBatchStats().model("M").batch_size(4).tpu_cost().Register(cost); +// +// To get the mean cost later, do: +// +// std::optional cost = +// GlobalBatchStats().model("M").batch_size(4).tpu_cost().mean(); +// +// It is allowed and safe to store references to intermediate objects here +// because all intermediate objects are guaranteed to never be destroyed. +// +// All operations supported by this API are thread-safe. + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_ + +#include +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow::serving { + +// Tracks the average cost of registered samples. +// +// Thread-safe. +class CostTracker { + public: + // Registers a cost sample. + void Register(absl::Duration cost) { + DCHECK_GT(cost, absl::ZeroDuration()); + + mutex_lock l(mu_); + sample_count_++; + sample_sum_ += cost; + }; + + // Returns the average cost of all registered samples, giving each sample + // the same weight. + // + // Returns std::nullopt if no samples have been registered. + std::optional mean() const { + int64_t count; + absl::Duration sum; + + { + // We only hold the lock to read the values and release it before later + // performing a relatively slow division operation. + mutex_lock l(mu_); + count = sample_count_; + sum = sample_sum_; + } + + if (count == 0) return std::nullopt; + + return sum / count; + }; + + private: + mutable mutex mu_; + + int64_t sample_count_ TF_GUARDED_BY(mu_) = 0; + absl::Duration sample_sum_ TF_GUARDED_BY(mu_); +}; + +class BatchSizeStats { + public: + CostTracker& tpu_cost() { return tpu_cost_; }; + + private: + CostTracker tpu_cost_; +}; + +// Thread-safe. +class ModelBatchStats { + public: + // Returns a reference to the BatchSizeStats instance for the given batch + // size. + // + // The returned reference persists for as long as 'this' is alive. + BatchSizeStats& batch_size(int32_t batch_size) { + mutex_lock l(mu_); + return batch_size_stats_by_batch_size_[batch_size]; + } + + private: + mutable mutex mu_; + + // The storage of all BatchSizeStats instances. + // + // The mutex only protects adding/finding elements in the map. Access to + // elements themselves (after they were created) is not protected here. No + // element deletion is possible because we return references to items in this + // map and don't track their lifetime. We are using the node hash map so that + // elements, once created, are fixed in memory. + absl::node_hash_map batch_size_stats_by_batch_size_ + TF_GUARDED_BY(mu_); +}; + +// Thread-safe. +class BatchStats { + public: + // Returns a reference to ModelBatchStats for the provided model name. + // + // Upon invocation with a not-yet-seen model_name, creates an empty + // ModelBatchStats for this model. + // + // The returned reference persists for as long as 'this' is alive. + ModelBatchStats& model(absl::string_view model_name) { + mutex_lock l(mu_); + return model_batch_stats_by_model_name_[model_name]; + } + + private: + mutable mutex mu_; + + // The storage of all ModelBatchStats instances. + // + // The mutex only protects adding/finding elements in the map. 
Access to + // elements themselves (after they were created) is not protected here. No + // element deletion is possible because we return references to items in this + // map and don't track their lifetime. We are using the node hash map so that + // elements, once created, are fixed in memory. + absl::node_hash_map + model_batch_stats_by_model_name_ TF_GUARDED_BY(mu_); +}; + +// Returns the global instance of BatchStats, to be used for all production +// purposes (one should only instantiate individual classes from this file to +// test them). +inline BatchStats& GlobalBatchStats() { + static BatchStats* instance = new BatchStats(); + return *instance; +} + +} // namespace tensorflow::serving + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_ diff --git a/tensorflow/core/kernels/batching_util/batch_stats_test.cc b/tensorflow/core/kernels/batching_util/batch_stats_test.cc new file mode 100644 index 00000000000000..3897f07c3c1306 --- /dev/null +++ b/tensorflow/core/kernels/batching_util/batch_stats_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/batching_util/batch_stats.h" + +#include +#include "absl/time/time.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow::serving { +namespace { + +TEST(BatchStatsTest, GlobalBatchStatsAlwaysReturnsTheSameInstance) { + ASSERT_EQ(&GlobalBatchStats(), &GlobalBatchStats()); +} + +TEST(BatchStatsTest, BasicOperation) { + BatchStats stats; + stats.model("a").batch_size(1).tpu_cost().Register(absl::Hours(5)); + ASSERT_EQ(stats.model("a").batch_size(1).tpu_cost().mean(), absl::Hours(5)); +} + +TEST(BatchStatsTest, ModelBatchStatsAreUniqueForEachModelName) { + BatchStats stats; + ASSERT_NE(&stats.model("a"), &stats.model("b")); +} + +TEST(BatchStatsTest, BatchSizeStatsAreUniqueForEachBatchSize) { + ModelBatchStats stats; + ASSERT_NE(&stats.batch_size(1), &stats.batch_size(2)); +} + +TEST(BatchStatsTest, CostTrackerStartsWithNoMean) { + CostTracker tracker; + + ASSERT_FALSE(tracker.mean().has_value()); +} + +TEST(BatchStatsTest, CostTrackerMeanIsCorrect) { + CostTracker tracker; + tracker.Register(absl::Hours(5)); + tracker.Register(absl::Hours(7)); + + ASSERT_EQ(*tracker.mean(), absl::Hours(6)); +} + +} // namespace + +} // namespace tensorflow::serving diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index e2598db69ae8ae..1cf5506488cc23 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -620,31 +620,13 @@ std::optional GetDescriptionForTiledTransposeEmitter( return std::nullopt; } -bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count, - const HloFusionAdaptor* fusion, - bool add_single_user_check) 
{ // Number of operands should be in range [1, allowed_operand_count]. if (instr->operand_count() == 0 || instr->operand_count() > allowed_operand_count) { return false; } - if (add_single_user_check) { - // Check that intermediate `instr` doesn't have multiple users. If we have a - // fusion, only consider users within the fusion. - // TODO(akuegel): Figure out why we still need this check for transpose - // fusions. - int64_t num_users = - fusion - ? absl::c_count_if( - HloInstructionAdaptor{*instr, fusion}.GetUsers(), - [&](auto user) { return fusion->ContainsInstruction(user); }) - : instr->user_count(); - if (num_users > 1) { - return false; - } - } - if (instr->IsElementwise()) { // All elementwise ops are considered intermediate, except for copies that // modify the layout. Copies that do not modify the layout are used in @@ -684,11 +666,7 @@ static std::optional FindNonTrivialHero( return TraversalResult::kSkip; } - // We set add_single_user_check to true because it could be that it causes - // problems if we have more than one user in a transpose fusion. - // TODO(akuegel): Verify and possibly fix this. - if (!IsIntermediate(&node.instruction(), /*allowed_operand_count=*/3, - /*fusion=*/nullptr, /*add_single_user_check=*/true)) { + if (!IsIntermediate(&node.instruction(), /*allowed_operand_count=*/3)) { return TraversalResult::kSkip; } return TraversalResult::kAdvance; @@ -700,14 +678,10 @@ static std::optional FindNonTrivialHero( // Make sure that no non-elementwise op is reachable from the transpose. auto is_nontrivial = [](HloInstructionAdaptor node) { - // We set add_single_user_check to true because it could be that it causes - // problems if we have more than one user in a transpose fusion. - // TODO(akuegel): Verify and possibly fix this. return node.instruction().opcode() != HloOpcode::kTuple && node.instruction().opcode() != HloOpcode::kParameter && !IsIntermediate(&node.instruction(), - /*allowed_operand_count=*/3, /*fusion=*/nullptr, - /*add_single_user_check=*/true); + /*allowed_operand_count=*/3); }; bool visit_operands = false; if (HloAnyOf(hero->GetUsers(), hero->parent(), is_nontrivial, @@ -724,8 +698,7 @@ HloInstructionAdaptor FindNonTrivialHero(const HloInstructionAdaptor& instr) { // Go up the chain of trivial element-wise(+bitcast, -copy) operations. Note // that no memoization is needed due to number of operands constraints: we // never have to revisit same nodes. - while (IsIntermediate(&hero.instruction(), /*allowed_operand_count=*/1, - &hero.parent()) && + while (IsIntermediate(&hero.instruction(), /*allowed_operand_count=*/1) && hero.parent().ContainsInstruction(hero.GetOperand(0))) { hero = hero.GetOperand(0); } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index c2ef21b8e7bb75..226f7b68df1b98 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -188,13 +188,8 @@ struct TransposeDescription { std::optional GetDescriptionForTiledTransposeEmitter( const HloInstruction& root, const HloInstruction& hero); -// Checks if the instruction is elementwise and only has a single user. If -// a fusion adaptor is provided, only checks for users within the fusion. If -// `add_single_user_check` is true, then it is also checked whether `instr` has -// at most 1 user. 
-bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1, - const HloFusionAdaptor* fusion = nullptr, - bool add_single_user_check = false); +// Checks if the instruction is elementwise. +bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count = 1); // Log the given module if the VLOG level is >= level. void VLogModule(int level, const llvm::Module& module); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 497cae1a2c91e5..37af036f42ece5 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -230,35 +230,6 @@ TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusionHeroAlsoUsedAsNonHero) { EXPECT_EQ(result2.name(), "reduce.0"); } -TEST_F(IrEmissionUtilsTest, DoNotFindTransposeHeroEpilogueFusionTwoRootUsers) { - const char* hlo = R"( - HloModule module - - fused_computation { - param_0 = f32[64,32]{1,0} parameter(0) - transpose = f32[32,64]{1,0} transpose(param_0), dimensions={1,0} - bitcast.1 = f32[1,32,64]{2,1,0} bitcast(transpose) - sign.1 = f32[1,32,64]{2,1,0} sign(bitcast.1) - ROOT tuple.12 = (f32[1,32,64]{2,1,0}, f32[1,32,64]{2,1,0}) tuple(bitcast.1, sign.1) - } - - ENTRY main.7749 { - Arg_2.1 = f32[64,32]{1,0} parameter(0) - ROOT fusion = (f32[1,32,64]{2,1,0}, f32[1,32,64]{2,1,0}) fusion(Arg_2.1), kind=kInput, calls=fused_computation - } - )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - - HloInstruction* r = module->entry_computation()->root_instruction(); - auto fusion = HloFusionAdaptor::ForInstruction(r); - const auto& result = FindNonTrivialHero(fusion->GetRoots()[0]); - EXPECT_EQ(result.name(), "bitcast.1"); - const auto& result2 = FindNonTrivialHero(fusion->GetRoots()[1]); - EXPECT_EQ(result2.name(), "sign.1"); -} - TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateBinaryOp) { const char* hlo = R"( HloModule module diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc index 3ff83a2c4100b0..499848aacb270c 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc @@ -1942,9 +1942,8 @@ ENTRY main { CheckGpuMultiOutputFusion(hlo, std::nullopt); } -// A variation of the test above, where no CSE was run, so we don't detect -// 'fusion' as a transpose fusion. -TEST_F(TransposeMultiOutputFusionTest, IncompatibleTransposesNoCSE) { +// A variation of the test above, where no CSE was run. 
+TEST_F(TransposeMultiOutputFusionTest, TransposesNoCSE) { const char* hlo = R"( HloModule module @@ -1974,20 +1973,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[18,16,32], param_1.1: f32[32,16,18]) -> (f32[32,16,18], f32[18,32,16]) { -// CHECK-NEXT: [[param_0:%[^ ]+]] = f32[18,16,32]{2,1,0} parameter(0) -// CHECK-NEXT: [[s_1:%[^ ]+]] = f32[18,16,32]{2,1,0} sqrt([[param_0]]) -// CHECK-NEXT: [[t_1:%[^ ]+]] = f32[32,16,18]{2,1,0} transpose([[s_1]]), dimensions={2,1,0} -// CHECK-NEXT: [[param_1:%[^ ]+]] = f32[32,16,18]{2,1,0} parameter(1) -// CHECK-NEXT: [[sub:%[^ ]+]] = f32[32,16,18]{2,1,0} subtract([[t_1]], [[param_1]]) -// CHECK-NEXT: [[exp_1:%[^ ]+]] = f32[32,16,18]{2,1,0} exponential([[sub]]) -// CHECK-NEXT: [[exp_2:%[^ ]+]] = f32[32,16,18]{2,1,0} exponential([[sub]]) -// CHECK-NEXT: [[add:%[^ ]+]] = f32[32,16,18]{2,1,0} add([[exp_1]], [[exp_2]]) -// CHECK-NEXT: [[s_2:%[^ ]+]] = f32[18,16,32]{2,1,0} sqrt([[param_0]]) -// CHECK-NEXT: [[t_2:%[^ ]+]] = f32[18,32,16]{2,1,0} transpose([[s_2]]), dimensions={0,2,1} -// CHECK-NEXT: ROOT %{{.*}} = (f32[32,16,18]{2,1,0}, f32[18,32,16]{2,1,0}) tuple([[add]], [[t_2]]) -})"); + CheckGpuMultiOutputFusion(hlo, std::nullopt); } TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) {