[go: nahoru, domu]

Skip to content

Commit

Permalink
Register TPU costs in the new batch_stats module.
Browse files Browse the repository at this point in the history
Reverts changelist 602615649

PiperOrigin-RevId: 630496098
  • Loading branch information
tensorflower-gardener committed May 24, 2024
1 parent eb1a971 commit 07174c1
Show file tree
Hide file tree
Showing 14 changed files with 741 additions and 87 deletions.
37 changes: 37 additions & 0 deletions tensorflow/core/kernels/batching_util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,30 @@ cc_library(
],
)

# Library target for batch_stats.h. The rule lists only `hdrs` (no `srcs`), so
# everything it provides comes from the header and the deps below.
cc_library(
name = "batch_stats",
hdrs = ["batch_stats.h"],
deps = [
"//tensorflow/core:framework_lite",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
],
)

# Unit test for the ":batch_stats" library; built from batch_stats_test.cc and
# linked against gtest_main.
tf_cc_test(
name = "batch_stats_test",
srcs = ["batch_stats_test.cc"],
deps = [
":batch_stats",
"//tensorflow/core:test",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "batch_input_task",
hdrs = ["batch_input_task.h"],
Expand Down Expand Up @@ -100,6 +124,7 @@ cc_library(
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:thread_annotations",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:criticality",
Expand All @@ -113,6 +138,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
Expand All @@ -126,7 +152,12 @@ cc_library(
srcs = ["batch_scheduler_utils.cc"],
hdrs = ["batch_scheduler_utils.h"],
deps = [
":batch_scheduler_hdrs",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
],
)

Expand Down Expand Up @@ -160,7 +191,9 @@ tf_cc_test(
name = "batch_scheduler_utils_test",
srcs = ["batch_scheduler_utils_test.cc"],
deps = [
":batch_scheduler_hdrs",
":batch_scheduler_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest_main",
],
)
Expand Down Expand Up @@ -402,6 +435,7 @@ cc_library(
":adaptive_shared_batch_scheduler",
":batch_scheduler",
":batch_scheduler_utils",
":batch_stats",
":concat_split_util",
":input_split_metadata",
":shared_batch_scheduler",
Expand Down Expand Up @@ -457,12 +491,15 @@ tf_cc_test(
srcs = ["batch_resource_base_test.cc"],
deps = [
":batch_resource_base",
":batch_stats",
"//tensorflow/core:framework",
"//tensorflow/core/common_runtime:cost_constants",
"//tensorflow/core/common_runtime:cost_measurement",
"//tensorflow/core/common_runtime:cost_measurement_registry",
"//tensorflow/core/common_runtime:no_op_cost_measurement",
"//tensorflow/core/common_runtime:request_cost",
"//tensorflow/core/framework:types_proto_cc",
"@com_google_absl//absl/random",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_resource_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
#include "tensorflow/core/kernels/batching_util/batch_stats.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/input_split_metadata.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
Expand Down Expand Up @@ -1235,6 +1236,15 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics(
absl::StrCat(cost_type, kNoSmearSuffix),
total_cost / processed_size * batch.size());

if (cost_type == kTpuCostName) {
// Register TPU cost for in-process use.
GlobalBatchStats()
.model(model_name)
.batch_size(processed_size)
.tpu_cost()
.Register(total_cost);
}

for (int i = 0; i < batch.num_tasks(); i++) {
RequestCost* request_cost = batch.task(i).request_cost;
// Skip recording the cost if the request_cost is null.
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_resource_base_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "tensorflow/core/common_runtime/cost_constants.h"
#include "tensorflow/core/common_runtime/cost_measurement.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/request_cost.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/batching_util/batch_stats.h"
#include "tsl/platform/criticality.h"

namespace tensorflow {
Expand Down Expand Up @@ -294,6 +296,39 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) {
UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100))))));
}

// Verifies that a TPU cost reported for a batch is also registered in
// GlobalBatchStats() under the batch's model name and processed size.
TEST(SplitBatchCostsAndRecordMetricsTest, UpdatesGlobalBatchStats) {
  // The processed batch size the cost is reported (and later looked up) under.
  constexpr int kProcessedSize = 17;

  // Use a model name that no other test would pick, so that the TPU cost for
  // this model name has either never been reported before or, if this test is
  // executed multiple times, has only ever been reported by this test.
  const char kModelName[] = __FILE__;

  // A cost measurement that always reports a fixed TPU cost.
  class FixedTpuCostMeasurement : public CostMeasurement {
   public:
    using CostMeasurement::CostMeasurement;
    absl::Duration GetTotalCost() override { return absl::Hours(555); }
    absl::string_view GetCostType() const override { return kTpuCostName; }
  };

  // A closed batch holding a single task (the batch must be non-empty).
  BatchResourceBase::BatchT batch;
  batch.AddTask(MakeBatchTask(/* task_size= */ 1, nullptr));
  batch.Close();

  // The batch-level measurements consist of exactly one TPU cost.
  CostMeasurement::Context context{/* is_per_query= */ false};
  std::vector<std::unique_ptr<CostMeasurement>> batch_cost_measurements;
  batch_cost_measurements.emplace_back(
      std::make_unique<FixedTpuCostMeasurement>(context));

  BatchResourceBase::SplitBatchCostsAndRecordMetrics(
      kModelName, batch_cost_measurements,
      /* processed_size= */ kProcessedSize, batch);

  // The only cost ever registered for this (model, batch size) pair is the one
  // above, so the mean must equal it exactly.
  const auto recorded_mean = GlobalBatchStats()
                                 .model(kModelName)
                                 .batch_size(kProcessedSize)
                                 .tpu_cost()
                                 .mean();
  EXPECT_EQ(recorded_mean, absl::Hours(555));
}

} // namespace
} // namespace serving
} // namespace tensorflow
49 changes: 47 additions & 2 deletions tensorflow/core/kernels/batching_util/batch_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ limitations under the License.
#include <atomic>
#include <cstddef>
#include <deque>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
Expand Down Expand Up @@ -304,6 +304,15 @@ class Batch {
// Returns the TraceMe context id of this batch.
uint64 traceme_context_id() const;

// Attempts to trim this batch to a new, smaller size (the sum of the task
// sizes, not to be confused with the number of tasks in the batch). On
// success, the trimmed tasks are moved into 'out_trimmed_tasks' in the same
// order the tasks were in this batch, and the batch's size becomes 'new_size'.
//
// Preconditions (DCHECKed): 'new_size' is positive and strictly smaller than
// the current size of the batch, and 'out_trimmed_tasks' is empty.
//
// Trimming only happens at task boundaries. If hitting 'new_size' exactly
// would require splitting a task, the method does nothing: the batch keeps all
// of its tasks and 'out_trimmed_tasks' stays empty.
void TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks);

private:
mutable mutex mu_;

Expand Down Expand Up @@ -505,6 +514,42 @@ uint64 Batch<TaskType>::traceme_context_id() const {
return traceme_context_id_;
}

// Trims this batch down to 'new_size' (sum of task sizes) if that can be done
// by dropping whole tasks from the back of the batch.
//
// Tasks are scanned front to back while accumulating their sizes. If the
// running total lands exactly on 'new_size' at a task boundary, every task
// after that boundary is moved, in order, into 'out_trimmed_tasks' and the
// batch's size becomes 'new_size'. If 'new_size' falls inside a task, nothing
// is modified and 'out_trimmed_tasks' stays empty.
//
// Preconditions (DCHECKed): 0 < new_size < current batch size, and
// 'out_trimmed_tasks' is empty. When the DCHECKs are compiled out, an
// out-of-range 'new_size' leaves the batch untouched instead of reading past
// the end of the task list.
template <typename TaskType>
void Batch<TaskType>::TryTrimToNewSize(
    int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
  mutex_lock l(mu_);
  DCHECK_GT(new_size, 0);
  DCHECK_LT(new_size, size_);
  DCHECK(out_trimmed_tasks.empty());

  // DCHECKs are debug-only. Bail out on a non-positive target so that builds
  // without them cannot empty the batch or store a bogus size.
  if (new_size <= 0) {
    return;
  }

  const int32 num_tasks = static_cast<int32>(tasks_.size());

  // Index of the first task to trim away. It is possible that it is the index
  // of a task of size larger than 1 that would have to be split in order to
  // get to the target new_size.
  int32 first_task_to_move = 0;
  // The sum of sizes of tasks i, where i < first_task_to_move.
  int32 size_of_previous_tasks = 0;
  // The bound on 'first_task_to_move' keeps the scan inside 'tasks_' even if
  // 'new_size' is not smaller than the batch size.
  while (first_task_to_move < num_tasks &&
         size_of_previous_tasks + tasks_[first_task_to_move]->size() <=
             new_size) {
    size_of_previous_tasks += tasks_[first_task_to_move]->size();
    first_task_to_move++;
  }

  // Check whether task 'first_task_to_move' would have to be split (or, when
  // 'new_size' exceeds the batch size, whether the target is unreachable).
  if (size_of_previous_tasks < new_size) {
    // TODO: b/325954758 - Consider supporting splitting large tasks and then
    // drop 'Try' from the method name.
    return;
  }
  DCHECK_EQ(size_of_previous_tasks, new_size);

  // Actually trim: hand over the tail of the task list, preserving order.
  out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move);
  std::move(tasks_.begin() + first_task_to_move, tasks_.end(),
            std::back_inserter(out_trimmed_tasks));
  tasks_.resize(first_task_to_move);
  size_ = new_size;
}

} // namespace serving
} // namespace tensorflow

Expand Down
51 changes: 50 additions & 1 deletion tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ limitations under the License.
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tsl/platform/criticality.h"
Expand All @@ -37,6 +38,7 @@ namespace {

using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Pointer;
using ::testing::Property;

TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) {
Expand Down Expand Up @@ -386,6 +388,53 @@ TEST(BatchTest, RemoveAllTasks) {
EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call
}

// Trimming at a task boundary keeps the leading tasks in the batch and hands
// back the trailing ones in their original order.
TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) {
  Batch<FakeTask> batch;

  // Add tasks of sizes 3, 5, 7 and 9 (total 24). Raw pointers are kept so the
  // identity of the trimmed tasks can be checked after ownership has moved.
  std::vector<FakeTask*> added_tasks;
  for (const int task_size : {3, 5, 7, 9}) {
    added_tasks.push_back(new FakeTask(task_size));
    batch.AddTask(std::unique_ptr<FakeTask>(added_tasks.back()));
  }

  // 8 == 3 + 5, i.e. exactly the boundary after the second task.
  std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
  batch.TryTrimToNewSize(/* new_size= */ 8,
                         /* out_trimmed_tasks= */ trimmed_tasks);

  EXPECT_EQ(batch.size(), 8);
  EXPECT_EQ(batch.num_tasks(), 2);

  EXPECT_THAT(trimmed_tasks,
              ElementsAre(Pointer(added_tasks[2]), Pointer(added_tasks[3])));

  batch.Close();  // Batch::~Batch blocks until the batch is closed.
}

// When the requested size falls inside a task, the batch must be left intact
// and no tasks may be handed back.
TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) {
  Batch<FakeTask> batch;

  // Two tasks of sizes 3 and 5 (total 8).
  for (const int task_size : {3, 5}) {
    batch.AddTask(std::make_unique<FakeTask>(task_size));
  }

  // A target of 4 lies inside the second task (3 < 4 < 3 + 5), so trimming
  // would require splitting it.
  std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
  batch.TryTrimToNewSize(/* new_size= */ 4,
                         /* out_trimmed_tasks= */ trimmed_tasks);

  EXPECT_EQ(batch.size(), 8);
  EXPECT_EQ(batch.num_tasks(), 2);
  EXPECT_TRUE(trimmed_tasks.empty());

  batch.Close();  // Batch::~Batch blocks until the batch is closed.
}

} // namespace
} // namespace serving
} // namespace tensorflow
Loading

0 comments on commit 07174c1

Please sign in to comment.