Introduce the MaybeBatchDown helper method, only supporting kBatchDown at this time.

This method is not used anywhere yet; it will be called from the SharedBatchScheduler in a follow-up CL.

Reverts changelist 602615649

PiperOrigin-RevId: 609369180
tensorflower-gardener committed May 24, 2024
1 parent eb1a971 commit ea18ce9
Showing 10 changed files with 435 additions and 87 deletions.
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/batching_util/BUILD
@@ -100,6 +100,7 @@ cc_library(
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:thread_annotations",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:criticality",
@@ -113,6 +114,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
@@ -126,7 +128,12 @@ cc_library(
srcs = ["batch_scheduler_utils.cc"],
hdrs = ["batch_scheduler_utils.h"],
deps = [
":batch_scheduler_hdrs",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
],
)

@@ -160,7 +167,9 @@ tf_cc_test(
name = "batch_scheduler_utils_test",
srcs = ["batch_scheduler_utils_test.cc"],
deps = [
":batch_scheduler_hdrs",
":batch_scheduler_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest_main",
],
)
49 changes: 47 additions & 2 deletions tensorflow/core/kernels/batching_util/batch_scheduler.h
@@ -32,18 +32,18 @@ limitations under the License.
#include <atomic>
#include <cstddef>
#include <deque>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -304,6 +304,15 @@ class Batch {
// Returns the TraceMe context id of this batch.
uint64 traceme_context_id() const;

// Attempts to trim this batch to a new, smaller size (not to be confused with
// the number of tasks in the batch). On success, the trimmed tasks go into
// 'out_trimmed_tasks' in the same order the tasks were in this batch.
//
// The method might not succeed if it needs to split a large task to hit the
// correct size.
void TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks);

private:
mutable mutex mu_;

@@ -505,6 +514,42 @@ uint64 Batch<TaskType>::traceme_context_id() const {
return traceme_context_id_;
}

template <typename TaskType>
void Batch<TaskType>::TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
mutex_lock l(mu_);
DCHECK_GT(new_size, 0);
DCHECK_LT(new_size, size_);
DCHECK(out_trimmed_tasks.empty());

// Index of the first task to trim away. It is possible that it is the index
// of a task of size larger than 1 that will have to be split in order to get
// to the target new_size.
int32 first_task_to_move = 0;
// The sum of sizes of tasks i, where i < first_task_to_move.
int32 size_of_previous_tasks = 0;
while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <=
new_size) {
size_of_previous_tasks += tasks_[first_task_to_move]->size();
first_task_to_move++;
}

// Check whether task 'first_task_to_move' will have to be split.
if (size_of_previous_tasks < new_size) {
// TODO: b/325954758 - Consider supporting splitting large tasks and then
// drop 'Try' from the method name.
return;
}
DCHECK_EQ(size_of_previous_tasks, new_size);

// Actually trim.
out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move);
std::move(tasks_.begin() + first_task_to_move, tasks_.end(),
std::back_inserter(out_trimmed_tasks));
tasks_.resize(first_task_to_move);
size_ = new_size;
}

} // namespace serving
} // namespace tensorflow

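For readers tracing the new TryTrimToNewSize logic above, here is a minimal standalone sketch (not part of this commit) of the same boundary search applied to a plain vector of task sizes. The function name FirstTaskToTrim and the sample numbers are invented for illustration.

#include <iostream>
#include <vector>

// Returns the index of the first task to trim so that the leading tasks sum to
// exactly new_size, or -1 if trimming would require splitting a task (the case
// the real TryTrimToNewSize declines to handle for now).
int FirstTaskToTrim(const std::vector<int>& task_sizes, int new_size) {
  int size_of_previous_tasks = 0;
  int first_task_to_move = 0;
  while (first_task_to_move < static_cast<int>(task_sizes.size()) &&
         size_of_previous_tasks + task_sizes[first_task_to_move] <= new_size) {
    size_of_previous_tasks += task_sizes[first_task_to_move];
    ++first_task_to_move;
  }
  return size_of_previous_tasks == new_size ? first_task_to_move : -1;
}

int main() {
  // Same shapes as the unit tests below: tasks of sizes {3, 5, 7, 9}.
  std::cout << FirstTaskToTrim({3, 5, 7, 9}, 8) << "\n";  // 2: trim the last two tasks.
  std::cout << FirstTaskToTrim({3, 5}, 4) << "\n";        // -1: would split the second task.
  return 0;
}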
51 changes: 50 additions & 1 deletion tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
@@ -21,12 +21,13 @@ limitations under the License.
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tsl/platform/criticality.h"
@@ -37,6 +38,7 @@ namespace {

using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Pointer;
using ::testing::Property;

TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) {
@@ -386,6 +388,53 @@ TEST(BatchTest, RemoveAllTasks) {
EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call
}

TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) {
Batch<FakeTask> batch;

auto task0 = new FakeTask(3);
batch.AddTask(std::unique_ptr<FakeTask>(task0));

auto task1 = new FakeTask(5);
batch.AddTask(std::unique_ptr<FakeTask>(task1));

auto task2 = new FakeTask(7);
batch.AddTask(std::unique_ptr<FakeTask>(task2));

auto task3 = new FakeTask(9);
batch.AddTask(std::unique_ptr<FakeTask>(task3));

std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
batch.TryTrimToNewSize(/* new_size= */ 8,
/* out_trimmed_tasks= */ trimmed_tasks);

EXPECT_EQ(batch.size(), 8);
EXPECT_EQ(batch.num_tasks(), 2);

EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3)));

batch.Close(); // Batch::~Batch blocks until the batch is closed.
}

TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) {
Batch<FakeTask> batch;

auto task0 = new FakeTask(3);
batch.AddTask(std::unique_ptr<FakeTask>(task0));

auto task1 = new FakeTask(5);
batch.AddTask(std::unique_ptr<FakeTask>(task1));

std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
batch.TryTrimToNewSize(/* new_size= */ 4,
/* out_trimmed_tasks= */ trimmed_tasks);

EXPECT_EQ(batch.size(), 8);
EXPECT_EQ(batch.num_tasks(), 2);
EXPECT_TRUE(trimmed_tasks.empty());

batch.Close(); // Batch::~Batch blocks until the batch is closed.
}

} // namespace
} // namespace serving
} // namespace tensorflow
72 changes: 72 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc
@@ -15,11 +15,30 @@ limitations under the License.

#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

#include <algorithm>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/flags/flag.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

ABSL_FLAG(tensorflow::serving::BatchPaddingPolicy,
tensorflow_batch_padding_policy,
tensorflow::serving::BatchPaddingPolicy::kPadUp,
"The policy that a batch schduler is using when deciding what to do "
"when, say, 18 requests need to be batched, but only 16 and 32 batch "
"sizes are allowed. The following options are available. PAD_UP: pad "
"to size 32. BATCH_DOWN: schedule a batch of size 16 and leave 2 "
"requests in the batch buffer. MINIMIZE_TPU_COST_PER_REQUEST: a "
"smarter greedy policy that chooses to either PAD_UP or BATCH_DOWN "
"so as to minimize the TPU costs per real request. In this case, it "
"would compare (batch_16_cost / 16) and (batch_32_cost / 18). "
"WARNING: not all batch schedulers might support this option.");

namespace tensorflow {
namespace serving {

@@ -40,5 +59,58 @@ int GetNextAllowedBatchSize(int batch_size,
return batch_size;
}

int32 GetPrevAllowedBatchSize(int batch_size,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding) {
if (disable_padding || allowed_batch_sizes.empty()) {
return batch_size;
}

DCHECK(absl::c_is_sorted(allowed_batch_sizes));
DCHECK_GT(batch_size, 0);

// Find the first allowed batch size, searching from the end, that is not
// larger than batch_size.
auto result = std::find_if(
allowed_batch_sizes.rbegin(), allowed_batch_sizes.rend(),
[&](int allowed_size) { return allowed_size <= batch_size; });

if (result == allowed_batch_sizes.rend()) {
// No such element exists.
return batch_size;
}

return *result;
}

bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out,
std::string* error) {
if (text == "PAD_UP") {
*out = BatchPaddingPolicy::kPadUp;
return true;
}
if (text == "BATCH_DOWN") {
*out = BatchPaddingPolicy::kBatchDown;
return true;
}
if (text == "MINIMIZE_TPU_COST_PER_REQUEST") {
*out = BatchPaddingPolicy::kMinimizeTpuCostPerRequest;
return true;
}
*error = "unrecognized batching policy string";
return false;
}

string AbslUnparseFlag(BatchPaddingPolicy in) {
switch (in) {
case BatchPaddingPolicy::kPadUp:
return "PAD_UP";
case BatchPaddingPolicy::kBatchDown:
return "BATCH_DOWN";
case BatchPaddingPolicy::kMinimizeTpuCostPerRequest:
return "MINIMIZE_TPU_COST_PER_REQUEST";
}
LOG(FATAL) << "Unrecognized BatchPaddingPolicy enum value.";  // Crash OK
}

} // namespace serving
} // namespace tensorflow
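To make the flag description above concrete, here is a small sketch (not part of this commit) of the pad-up versus batch-down arithmetic for 18 queued requests and allowed batch sizes {16, 32}. It assumes the batch_scheduler_utils library is linked; the per-batch cost numbers are invented, since a real kMinimizeTpuCostPerRequest policy would take them from a cost model.

#include <cstdio>
#include <vector>

#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

int main() {
  const std::vector<tensorflow::int32> allowed_batch_sizes = {16, 32};
  const int batch_size = 18;

  // PAD_UP target: the next allowed size >= 18, i.e. 32.
  const int pad_up_size = tensorflow::serving::GetNextAllowedBatchSize(
      batch_size, allowed_batch_sizes, /*disable_padding=*/false);
  // BATCH_DOWN target: the previous allowed size <= 18, i.e. 16.
  const int batch_down_size = tensorflow::serving::GetPrevAllowedBatchSize(
      batch_size, allowed_batch_sizes, /*disable_padding=*/false);

  // Hypothetical per-batch TPU costs.
  const double batch_32_cost = 17.0;
  const double batch_16_cost = 10.0;

  // PAD_UP serves all 18 real requests with one padded batch of 32.
  const double pad_up_cost_per_request = batch_32_cost / batch_size;
  // BATCH_DOWN serves 16 real requests now and leaves 2 in the queue.
  const double batch_down_cost_per_request = batch_16_cost / batch_down_size;

  std::printf("pad up to %d: %.3f per request\n", pad_up_size,
              pad_up_cost_per_request);
  std::printf("batch down to %d: %.3f per request\n", batch_down_size,
              batch_down_cost_per_request);
  return 0;
}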
73 changes: 73 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
@@ -16,10 +16,25 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/flags/declare.h"
#include "absl/flags/flag.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow::serving {
enum class BatchPaddingPolicy; // Forward-declaring for the ABSL_DECLARE_FLAG.
} // namespace tensorflow::serving

// Exposed for testing only.
ABSL_DECLARE_FLAG(tensorflow::serving::BatchPaddingPolicy,
tensorflow_batch_padding_policy);

namespace tensorflow {
namespace serving {

@@ -30,6 +45,64 @@ int GetNextAllowedBatchSize(int batch_size,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding);

// Returns the largest allowed batch size that is smaller than or equal to
// batch_size. Returns batch_size if no such size exists.
int GetPrevAllowedBatchSize(int batch_size,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding);

// See the description of the --tensorflow_batch_padding_policy flag (in the
// .cc file) for the documentation.
enum class BatchPaddingPolicy {
kPadUp,
kBatchDown,
kMinimizeTpuCostPerRequest,
};
bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out,
std::string* error);
std::string AbslUnparseFlag(BatchPaddingPolicy in);

// Trims the batch down to the previous allowed batch size when possible and
// when configured to do so by the --tensorflow_batch_padding_policy flag.
//
// When trimming, this function puts the trimmed tasks into the
// out_trimmed_tasks vector in the same order as they were in the batch.
template <typename TaskType>
void MaybeBatchDown(Batch<TaskType>& batch,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding,
std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
switch (absl::GetFlag(FLAGS_tensorflow_batch_padding_policy)) {
case BatchPaddingPolicy::kPadUp:
// This is the default behavior of batch resource when it is given a batch
// size that doesn't match any of the allowed batch sizes.
return;
case BatchPaddingPolicy::kBatchDown:
// Continue with this method.
break;
case BatchPaddingPolicy::kMinimizeTpuCostPerRequest:
LOG(DFATAL) << "BatchPaddingPolicy::kMinimizeTpuCostPerRequest is not "
"yet implemented, falling back on kBatchDown.";
break;
}

int32 batch_size = batch.size();

int32 pad_up_size =
GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
if (pad_up_size == batch_size) {
return; // Good, no padding is necessary.
}

int32 batch_down_size =
GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding);
if (batch_down_size == batch_size) {
return; // Can't batch down (e.g. no smaller batch size available).
}

batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks);
}

} // namespace serving
} // namespace tensorflow

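The following is an illustrative call-site sketch (not part of this commit) of how a scheduler such as the SharedBatchScheduler mentioned in the commit message might use MaybeBatchDown before executing a batch. ScheduleBatchSketch, ExecuteBatch, and ReEnqueue are placeholder names, not real TensorFlow APIs.

#include <memory>
#include <vector>

#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

namespace tensorflow {
namespace serving {

template <typename TaskType>
void ScheduleBatchSketch(Batch<TaskType>& batch,
                         const std::vector<int32>& allowed_batch_sizes,
                         bool disable_padding) {
  // With the default kPadUp policy this is a no-op; with kBatchDown the batch
  // is trimmed to the previous allowed size, but only when that can be done
  // without splitting a task.
  std::vector<std::unique_ptr<TaskType>> trimmed_tasks;
  MaybeBatchDown(batch, allowed_batch_sizes, disable_padding, trimmed_tasks);

  // ExecuteBatch(batch);                  // Placeholder: run the (possibly trimmed) batch.
  // ReEnqueue(std::move(trimmed_tasks));  // Placeholder: return leftovers to the queue.
}

}  // namespace serving
}  // namespace tensorflow

In a test, the policy can be selected with absl::SetFlag(&FLAGS_tensorflow_batch_padding_policy, tensorflow::serving::BatchPaddingPolicy::kBatchDown); in a binary whose command line is parsed by Abseil flags, --tensorflow_batch_padding_policy=BATCH_DOWN has the same effect.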
