Add support for batch padding policies into SharedBatchScheduler.
PiperOrigin-RevId: 633192959
tensorflower-gardener committed Jul 5, 2024
1 parent 1287587 commit d2e24de
Showing 7 changed files with 558 additions and 4 deletions.
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/batching_util/BUILD
@@ -122,6 +122,7 @@ cc_library(
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:thread_annotations",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:criticality",
@@ -135,6 +136,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
@@ -148,8 +150,12 @@ cc_library(
srcs = ["batch_scheduler_utils.cc"],
hdrs = ["batch_scheduler_utils.h"],
deps = [
":batch_scheduler_hdrs",
":batch_stats",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/time",
],
)

@@ -183,7 +189,10 @@ tf_cc_test(
name = "batch_scheduler_utils_test",
srcs = ["batch_scheduler_utils_test.cc"],
deps = [
":batch_scheduler_hdrs",
":batch_scheduler_utils",
":batch_stats",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)
@@ -200,6 +209,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core/kernels/batching_util:batch_stats",
"//tensorflow/core/lib/core:errors",
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:errors",
@@ -227,6 +237,7 @@ cc_library(
":batch_scheduler_utils",
":periodic_function_dynamic",
"//tensorflow/core:lib",
"//tensorflow/core/kernels/batching_util:batch_stats",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:context_types_hdrs",
"//tensorflow/core/profiler/lib:traceme",
@@ -252,6 +263,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels/batching_util:batch_scheduler_utils",
"//tensorflow/core/platform:status_matchers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:fixed_array",
54 changes: 51 additions & 3 deletions tensorflow/core/kernels/batching_util/batch_scheduler.h
@@ -32,18 +32,18 @@ limitations under the License.
#include <atomic>
#include <cstddef>
#include <deque>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -252,7 +252,7 @@ int TaskQueue<TaskType>::size() const {
// accept new tasks; a closed one cannot. A batch is monotonic: initially it is
// open and tasks can be added to it; then it is closed and its set of tasks
// remains fixed for the remainder of its life. A closed batch cannot be re-
// opened. Tasks can never be removed from a batch.
// opened.
//
// Type parameter TaskType must be a subclass of BatchTask.
template <typename TaskType>
@@ -304,6 +304,15 @@ class Batch {
// Returns the TraceMe context id of this batch.
uint64 traceme_context_id() const;

// Attempts to trim this batch to a new, smaller size (not to be confused with
// the number of tasks in the batch). On success, the trimmed tasks go into
// 'out_trimmed_tasks' in the same order the tasks were in this batch.
//
// The method might not succeed if it needs to split a large task to hit the
// correct size.
void TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks);

private:
mutable mutex mu_;

@@ -505,6 +514,45 @@ uint64 Batch<TaskType>::traceme_context_id() const {
return traceme_context_id_;
}

template <typename TaskType>
void Batch<TaskType>::TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
mutex_lock l(mu_);
DCHECK_GT(new_size, 0);
DCHECK_LT(new_size, size_);
DCHECK(out_trimmed_tasks.empty());

// Index of the first task to trim away. It is possible that it is the index
// of a task of size larger than 1 that will have to be split in order to get
// to the target new_size.
int32 first_task_to_move = 0;
// The sum of sizes of tasks i, where i < first_task_to_move.
int32 size_of_previous_tasks = 0;
while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <=
new_size) {
size_of_previous_tasks += tasks_[first_task_to_move]->size();
first_task_to_move++;
// The loop must always stop before this check is tripped because new_size
// must never be larger than the size of the batch.
DCHECK_LT(first_task_to_move, tasks_.size());
}

// Check whether task 'first_task_to_move' will have to be split.
if (size_of_previous_tasks < new_size) {
// TODO: b/325954758 - Consider supporting splitting large tasks and then
// drop 'Try' from the method name.
return;
}
DCHECK_EQ(size_of_previous_tasks, new_size);

// Actually trim.
out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move);
std::move(tasks_.begin() + first_task_to_move, tasks_.end(),
std::back_inserter(out_trimmed_tasks));
tasks_.resize(first_task_to_move);
size_ = new_size;
}

} // namespace serving
} // namespace tensorflow

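The loop inside TryTrimToNewSize above is a prefix-sum scan over task sizes: it keeps advancing while the cumulative size of the retained tasks still fits within new_size, and trimming only proceeds when that prefix sum lands exactly on new_size. A standalone sketch of the same scan; FirstTaskToMove is a hypothetical helper written for illustration and is not part of this commit:

#include <vector>

// Returns how many leading tasks sum to exactly new_size, or -1 if reaching
// new_size would require splitting a task (the case TryTrimToNewSize skips).
// Assumes 0 < new_size < total size of all tasks, mirroring the DCHECKs above.
int FirstTaskToMove(const std::vector<int>& task_sizes, int new_size) {
  int first_task_to_move = 0;      // index of the first task to trim away
  int size_of_previous_tasks = 0;  // total size of tasks [0, first_task_to_move)
  while (size_of_previous_tasks + task_sizes[first_task_to_move] <= new_size) {
    size_of_previous_tasks += task_sizes[first_task_to_move];
    ++first_task_to_move;
  }
  return size_of_previous_tasks == new_size ? first_task_to_move : -1;
}

// FirstTaskToMove({3, 5, 7, 9}, 8) == 2: tasks 0 and 1 stay, tasks 2 and 3 get
// trimmed, matching the first test below. FirstTaskToMove({3, 5}, 4) == -1:
// hitting 4 would split the size-5 task, so the second test below trims nothing.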
51 changes: 50 additions & 1 deletion tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
@@ -21,12 +21,13 @@ limitations under the License.
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tsl/platform/criticality.h"
@@ -37,6 +38,7 @@ namespace {

using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Pointer;
using ::testing::Property;

TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) {
@@ -386,6 +388,53 @@ TEST(BatchTest, RemoveAllTasks) {
EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call
}

TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) {
Batch<FakeTask> batch;

auto task0 = new FakeTask(3);
batch.AddTask(std::unique_ptr<FakeTask>(task0));

auto task1 = new FakeTask(5);
batch.AddTask(std::unique_ptr<FakeTask>(task1));

auto task2 = new FakeTask(7);
batch.AddTask(std::unique_ptr<FakeTask>(task2));

auto task3 = new FakeTask(9);
batch.AddTask(std::unique_ptr<FakeTask>(task3));

std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
batch.TryTrimToNewSize(/* new_size= */ 8,
/* out_trimmed_tasks= */ trimmed_tasks);

EXPECT_EQ(batch.size(), 8);
EXPECT_EQ(batch.num_tasks(), 2);

EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3)));

batch.Close(); // Batch::~Batch blocks until the batch is closed.
}

TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) {
Batch<FakeTask> batch;

auto task0 = new FakeTask(3);
batch.AddTask(std::unique_ptr<FakeTask>(task0));

auto task1 = new FakeTask(5);
batch.AddTask(std::unique_ptr<FakeTask>(task1));

std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
batch.TryTrimToNewSize(/* new_size= */ 4,
/* out_trimmed_tasks= */ trimmed_tasks);

EXPECT_EQ(batch.size(), 8);
EXPECT_EQ(batch.num_tasks(), 2);
EXPECT_TRUE(trimmed_tasks.empty());

batch.Close(); // Batch::~Batch blocks until the batch is closed.
}

} // namespace
} // namespace serving
} // namespace tensorflow
114 changes: 114 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
@@ -16,8 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_

#include <memory>
#include <optional>
#include <string_view>
#include <vector>

#include "absl/log/log.h"
#include "absl/time/time.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_stats.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
@@ -36,6 +43,113 @@ int GetPrevAllowedBatchSize(int batch_size,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding);

// Constants containing possible values for the batch_padding_policy argument
// of MaybeBatchDown. This argument specifies the policy that a batch scheduler
// is using when deciding what to do when, say, 18 requests need to be batched,
// but only 16 and 32 batch sizes are allowed. The following options are
// available.
//
// - PAD_UP: pad to size 32.
// - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the
// batch buffer.
// - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses
// to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per
// real request. In this case, it would compare (batch_16_cost / 16) and
// (batch_32_cost / 18).
//
inline constexpr std::string_view kBatchDownPolicy = "BATCH_DOWN";
inline constexpr std::string_view kPadUpPolicy = "PAD_UP";
inline constexpr std::string_view kMinimizeTpuCostPerRequestPolicy =
"MINIMIZE_TPU_COST_PER_REQUEST";

// Trims the batch down to the previous allowed batch size when possible and
// when configured to do so by batch_padding_policy.
//
// When trimming, this function puts the trimmed tasks into the
// out_trimmed_tasks vector in the same order as they were in the batch.
template <typename TaskType>
void MaybeBatchDown(Batch<TaskType>& batch,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding, std::string_view batch_padding_policy,
ModelBatchStats* model_batch_stats,
std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
if (batch_padding_policy == kPadUpPolicy) {
// This is the default behavior of batch resource when it is given a batch
// size that doesn't match any of the allowed batch sizes.
return;
}
bool minimize_tpu_cost_per_request;
if (batch_padding_policy == kBatchDownPolicy) {
minimize_tpu_cost_per_request = false;
} else if (batch_padding_policy == kMinimizeTpuCostPerRequestPolicy) {
if (model_batch_stats == nullptr) {
LOG_FIRST_N(DFATAL, 1)
<< kMinimizeTpuCostPerRequestPolicy
<< " batch padding policy has been chosen "
"but no ModelBatchStats passed to the batch scheduler; will "
"fall back on the "
<< kPadUpPolicy << " policy.";
return;
}
minimize_tpu_cost_per_request = true;
} else {
LOG_FIRST_N(DFATAL, 1) << "Unsupported batch_padding_policy: "
<< batch_padding_policy << ", falling back on the "
<< kPadUpPolicy << " policy.";
return;
}

int32 batch_size = batch.size();

int32 pad_up_size =
GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
if (pad_up_size == batch_size) {
return; // Good, no padding is necessary.
}

int32 batch_down_size =
GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
if (batch_down_size == batch_size) {
return; // Can't batch down (e.g. no smaller batch size available).
}

if (minimize_tpu_cost_per_request) {
// TODO: b/325954758 - Consider logging a warning here or elsewhere if
// a larger batch isn't meaningfully cheaper than a smaller batch.
// TODO: b/325954758 - Consider logging a warning here or elsewhere if a
// smaller batch is unreasonably cheaper than a larger one (assuming
// a batch cost model = constant_cost + batch_size * per_element_cost).
// TODO: b/325954758 - Consider occasionally picking either batch size so
// that we learn fresh costs of each batch size. For this code, it is not a
// large priority though because if we are in between two allowed batch
// sizes (say, 16 and 32), chances are that we will occasionally organically
// get batches of exact sizes 16 and 32 (and then we pick those
// unconditionally). But if we explicitly occasionally explored other batch
// sizes, we wouldn't have to rely on this "chances are". For other
// applications of batch costs, we might also want to occasionally explore
// all allowed batch sizes and not just 16 and 32 from this example.
std::optional<absl::Duration> down_batch_cost =
model_batch_stats->batch_size(batch_down_size).tpu_cost().mean();
std::optional<absl::Duration> up_batch_cost =
model_batch_stats->batch_size(pad_up_size).tpu_cost().mean();
if (!down_batch_cost.has_value() || !up_batch_cost.has_value()) {
// We have no data about batch costs, so just do nothing.
return;
}

auto batch_down_cost_per_request = *down_batch_cost / batch_down_size;
auto pad_up_cost_per_request = *up_batch_cost / batch_size;

if (pad_up_cost_per_request < batch_down_cost_per_request) {
// Abort batching down because it's cheaper to pad up.
return;
}
}

// Batch down.
batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks);
}

} // namespace serving
} // namespace tensorflow

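To make the MINIMIZE_TPU_COST_PER_REQUEST comparison in MaybeBatchDown concrete, here is the arithmetic from the policy comment's example with hypothetical mean costs; the numbers below are illustrative only, whereas in the real code they come from ModelBatchStats:

#include "absl/time/time.h"

// Hypothetical scenario matching the policy comment above: 18 requests are
// queued and the allowed batch sizes are 16 and 32.
bool ShouldBatchDown() {
  absl::Duration down_batch_cost = absl::Milliseconds(10);  // mean cost of a size-16 batch (assumed)
  absl::Duration up_batch_cost = absl::Milliseconds(16);    // mean cost of a size-32 batch (assumed)
  // Note the asymmetry mirrored from MaybeBatchDown: the padded batch's cost is
  // divided by the number of real requests (18), not by the padded size (32).
  absl::Duration batch_down_cost_per_request = down_batch_cost / 16;  // 0.625 ms
  absl::Duration pad_up_cost_per_request = up_batch_cost / 18;        // ~0.89 ms
  // Padding up is more expensive per real request here, so the scheduler trims
  // the batch down to 16 and leaves 2 requests in the queue for a later batch.
  return pad_up_cost_per_request >= batch_down_cost_per_request;
}

With these numbers the comparison favors batching down, which corresponds to MaybeBatchDown falling through to batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks); if the larger batch were cheap enough per real request, the function would instead return early and leave the batch to be padded up.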