Commit ebcdeef

Introduce the --tensorflow_batch_padding_policy flag.
PiperOrigin-RevId: 628043317
tensorflower-gardener committed May 16, 2024
1 parent c9e8a7f commit ebcdeef
Showing 4 changed files with 102 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/batching_util/BUILD
@@ -127,6 +127,9 @@ cc_library(
hdrs = ["batch_scheduler_utils.h"],
deps = [
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
],
)

47 changes: 47 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc
@@ -15,11 +15,28 @@ limitations under the License.

#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

ABSL_FLAG(tensorflow::serving::BatchPaddingPolicy,
tensorflow_batch_padding_policy,
tensorflow::serving::BatchPaddingPolicy::kPadUp,
"The policy that a batch schduler is using when deciding what to do "
"when, say, 18 requests need to be batched, but only 16 and 32 batch "
"sizes are allowed. The following options are available. PAD_UP: pad "
"to size 32. BATCH_DOWN: schedule a batch of size 16 and leave 2 "
"requests in the batch buffer. MINIMIZE_TPU_COST_PER_REQUEST: a "
"smarter greedy policy that chooses to either PAD_UP or BATCH_DOWN "
"so as to minimize the TPU costs per real request. In this case, it "
"would compare (batch_16_cost / 16) and (batch_32_cost / 18). "
"WARNING: not all batch schedulers might support this option.");

namespace tensorflow {
namespace serving {

@@ -40,5 +57,35 @@ int GetNextAllowedBatchSize(int batch_size,
return batch_size;
}

bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out,
std::string* error) {
if (text == "PAD_UP") {
*out = BatchPaddingPolicy::kPadUp;
return true;
}
if (text == "BATCH_DOWN") {
*out = BatchPaddingPolicy::kBatchDown;
return true;
}
if (text == "MINIMIZE_TPU_COST_PER_REQUEST") {
*out = BatchPaddingPolicy::kMinimizeTpuCostPerRequest;
return true;
}
*error = "unrecognized batching policy string";
return false;
}

std::string AbslUnparseFlag(BatchPaddingPolicy in) {
switch (in) {
case BatchPaddingPolicy::kPadUp:
return "PAD_UP";
case BatchPaddingPolicy::kBatchDown:
return "BATCH_DOWN";
case BatchPaddingPolicy::kMinimizeTpuCostPerRequest:
return "MINIMIZE_TPU_COST_PER_REQUEST";
}
LOG(FATAL) << "Unrecognized BatchPaddingPolicy enum value."; // Crash OK
}

} // namespace serving
} // namespace tensorflow
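
For illustration, the following is a minimal, self-contained sketch (not part of this commit) of the greedy comparison that MINIMIZE_TPU_COST_PER_REQUEST performs, as described in the flag text above. The BatchCost() cost model is a hypothetical stand-in for whatever measured or modeled per-batch costs a real scheduler would use.

#include <iostream>

namespace {

// Hypothetical per-batch TPU cost model, for illustration only; a real
// scheduler would plug in measured or modeled batch costs here.
double BatchCost(int batch_size) { return 1.0 + 0.05 * batch_size; }

// Greedy choice between padding up and batching down: compare the TPU cost
// per *real* request under each option. For 18 requests with allowed sizes
// {16, 32}, this compares (batch_16_cost / 16) against (batch_32_cost / 18).
bool ShouldPadUp(int num_requests, int size_down, int size_up) {
  const double cost_per_request_down = BatchCost(size_down) / size_down;
  const double cost_per_request_up = BatchCost(size_up) / num_requests;
  return cost_per_request_up <= cost_per_request_down;
}

}  // namespace

int main() {
  // With this cost model: 1.8 / 16 = 0.1125 (down) vs. 2.6 / 18 ≈ 0.144 (up),
  // so the policy batches down and leaves 2 requests in the buffer.
  std::cout << (ShouldPadUp(18, 16, 32) ? "PAD_UP" : "BATCH_DOWN") << "\n";
  return 0;
}
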
23 changes: 23 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils.h
@@ -16,10 +16,21 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_

#include <string>
#include <vector>

#include "absl/flags/declare.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow::serving {
enum class BatchPaddingPolicy; // Forward-declaring for the ABSL_DECLARE_FLAG.
} // namespace tensorflow::serving

// Exposed for testing only.
ABSL_DECLARE_FLAG(tensorflow::serving::BatchPaddingPolicy,
tensorflow_batch_padding_policy);

namespace tensorflow {
namespace serving {

@@ -30,6 +41,18 @@ int GetNextAllowedBatchSize(int batch_size,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding);

// See the description of the --tensorflow_batch_padding_policy flag (in the
// .cc file) for the documentation.
enum class BatchPaddingPolicy {
kPadUp,
kBatchDown,
kMinimizeTpuCostPerRequest,
};

bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out,
std::string* error);
std::string AbslUnparseFlag(BatchPaddingPolicy in);

} // namespace serving
} // namespace tensorflow
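
As a usage sketch (assumed, not part of this commit): because the flag is defined with ABSL_FLAG and declared in this header, a binary that links the library can set the policy on the command line and read it back through the standard Abseil accessors. The main() below is hypothetical.

#include <iostream>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

int main(int argc, char** argv) {
  // Accepts, e.g., --tensorflow_batch_padding_policy=BATCH_DOWN; the string
  // is converted through the AbslParseFlag() overload declared above.
  absl::ParseCommandLine(argc, argv);

  const tensorflow::serving::BatchPaddingPolicy policy =
      absl::GetFlag(FLAGS_tensorflow_batch_padding_policy);

  // AbslUnparseFlag() recovers the canonical spelling for logging.
  std::cout << "padding policy: "
            << tensorflow::serving::AbslUnparseFlag(policy) << "\n";
  return 0;
}
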

29 changes: 29 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

#include <string>

#include <gtest/gtest.h>

namespace tensorflow {
@@ -42,6 +44,33 @@ TEST(GetNextAllowedBatchSizeTest, GreaterThanAllowedBatchSize) {
EXPECT_EQ(GetNextAllowedBatchSize(10, {2, 4, 8}, false), 10);
}

TEST(BatchPaddingPolicyTest, AbslParseFlag) {
std::string error;
BatchPaddingPolicy policy;

EXPECT_TRUE(AbslParseFlag("PAD_UP", &policy, &error));
EXPECT_EQ(policy, BatchPaddingPolicy::kPadUp);
EXPECT_EQ(error, "");

EXPECT_TRUE(AbslParseFlag("BATCH_DOWN", &policy, &error));
EXPECT_EQ(policy, BatchPaddingPolicy::kBatchDown);
EXPECT_EQ(error, "");

EXPECT_TRUE(AbslParseFlag("MINIMIZE_TPU_COST_PER_REQUEST", &policy, &error));
EXPECT_EQ(policy, BatchPaddingPolicy::kMinimizeTpuCostPerRequest);
EXPECT_EQ(error, "");

EXPECT_FALSE(AbslParseFlag("cucumber", &policy, &error));
EXPECT_NE(error, "");
}

TEST(BatchPaddingPolicyTest, AbslUnparseFlag) {
EXPECT_EQ(AbslUnparseFlag(BatchPaddingPolicy::kPadUp), "PAD_UP");
EXPECT_EQ(AbslUnparseFlag(BatchPaddingPolicy::kBatchDown), "BATCH_DOWN");
EXPECT_EQ(AbslUnparseFlag(BatchPaddingPolicy::kMinimizeTpuCostPerRequest),
"MINIMIZE_TPU_COST_PER_REQUEST");
}

} // namespace

} // namespace serving
