[go: nahoru, domu]

Skip to content

Commit

Permalink
Enable profiler tracing support for threadpools
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621967537
  • Loading branch information
cliveverghese authored and tensorflower-gardener committed May 17, 2024
1 parent 4d73219 commit d8e28bb
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 4 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/data/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ cc_library(
"//tensorflow/core/platform:status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@local_tsl//tsl/profiler/backends/cpu:threadpool_listener_state",
],
)

Expand Down
1 change: 1 addition & 0 deletions third_party/xla/third_party/tsl/tsl/platform/default/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ cc_library(
"//tsl/platform:strcat",
"//tsl/platform:stringpiece",
"//tsl/platform:types",
"//tsl/profiler/backends/cpu:threadpool_listener_state",
],
alwayslink = True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_TRACING_IMPL_H_
#define TENSORFLOW_TSL_PLATFORM_DEFAULT_TRACING_IMPL_H_

#ifndef IS_MOBILE_PLATFORM
#include "tsl/profiler/backends/cpu/threadpool_listener_state.h"
#endif
// Stub implementations of tracing functionality.

// Definitions that do nothing for platforms that don't have underlying thread
Expand All @@ -33,7 +36,13 @@ limitations under the License.
namespace tsl {
namespace tracing {

// NOTE(review): this hunk was scraped from a diff view that lost its +/-
// markers. The hunk header reads "-33,7 +36,13": the one-line definition
// directly below is the line REMOVED by this commit, and the multi-line
// definition after it is its replacement. Only one of them exists in the
// actual header at any given revision.
inline bool EventCollector::IsEnabled() { return false; }
// Reports whether tracing event collection is enabled.
// On non-mobile builds this defers to
// tsl::profiler::threadpool_listener::IsEnabled() (whose header is included
// only under the same #ifndef IS_MOBILE_PLATFORM guard). On mobile builds it
// keeps the previous stub behavior and always returns false.
inline bool EventCollector::IsEnabled() {
#ifndef IS_MOBILE_PLATFORM
return tsl::profiler::threadpool_listener::IsEnabled();
#else
return false;
#endif
}

} // namespace tracing
} // namespace tsl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,5 @@ cc_library(
name = "threadpool_listener_state",
srcs = ["threadpool_listener_state.cc"],
hdrs = ["threadpool_listener_state.h"],
visibility = internal_visibility([
"//tsl/platform:__subpackages__",
]),
visibility = ["//visibility:public"],
)
70 changes: 70 additions & 0 deletions third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,25 @@ limitations under the License.
==============================================================================*/
#include "xla/backends/profiler/cpu/host_tracer.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <string>

#include <gtest/gtest.h>
#include "absl/types/optional.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
#include "tsl/platform/threadpool.h"
#include "tsl/platform/types.h"
#include "tsl/profiler/lib/profiler_interface.h"
#include "tsl/profiler/lib/traceme.h"
#include "tsl/profiler/protobuf/xplane.pb.h"
#include "tsl/profiler/utils/tf_xplane_visitor.h"
#include "tsl/profiler/utils/timespan.h"
#include "tsl/profiler/utils/xplane_schema.h"
#include "tsl/profiler/utils/xplane_visitor.h"

Expand All @@ -37,8 +43,11 @@ namespace {
using ::tsl::Env;
using ::tsl::Thread;
using ::tsl::ThreadOptions;
using ::tsl::profiler::StatType;
using ::tsl::profiler::Timespan;
using ::tsl::profiler::TraceMe;
using ::tsl::profiler::XEventVisitor;
using ::tsl::profiler::XLineVisitor;
using ::tsl::profiler::XPlaneVisitor;
using ::tsl::profiler::XStatVisitor;

Expand Down Expand Up @@ -153,6 +162,67 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) {
EXPECT_EQ(e6.DisplayName(), "Iterator::ParallelMap");
}

// Verifies that a closure scheduled on a tsl::thread::ThreadPool while the host
// tracer is running produces the threadpool-listener events (record, start
// region, stop region). The record and start-region events must carry the same
// producer/consumer id, and the region must enclose the TraceMe emitted inside
// the closure.
TEST(HostTracerTest, CollectEventsFromThreadPool) {
  auto thread_pool =
      std::make_unique<tsl::thread::ThreadPool>(/*env=*/Env::Default(),
                                                /*name=*/"HostTracerTest",
                                                /*num_threads=*/1);
  tsl::BlockingCounter counter(1);
  auto tracer = CreateHostTracer({});
  TF_EXPECT_OK(tracer->Start());
  thread_pool->Schedule([&counter] {
    TraceMe traceme("hello");
    counter.DecrementCount();
  });
  counter.Wait();
  // The counter is decremented while `traceme` is still alive. The region is
  // expected to end after `traceme` ends (see the Includes() check below).
  // Waiting on the counter alone therefore does not guarantee those events
  // have been recorded yet. Destroying the pool joins its worker thread, so
  // all events from the closure exist before the tracer is stopped. Without
  // this, Stop() races the worker and the test is flaky.
  thread_pool.reset();
  TF_EXPECT_OK(tracer->Stop());
  tensorflow::profiler::XSpace space;
  TF_EXPECT_OK(tracer->CollectData(&space));

  // Fatal: planes(0) below would be out of range if no plane was produced.
  ASSERT_THAT(space.planes(), testing::SizeIs(1));
  XPlaneVisitor xplane = tsl::profiler::CreateTfXPlaneVisitor(&space.planes(0));

  bool has_record_event = false;
  bool has_start_region_event = false;
  bool has_end_region_event = false;
  bool has_traceme_event = false;
  int64_t record_region_id = 0;
  int64_t start_region_id = 0;

  Timespan region_timespan;
  Timespan traceme_timespan;

  xplane.ForEachLine([&](const XLineVisitor& line) {
    line.ForEachEvent([&](const XEventVisitor& event) {
      if (event.Name() == tsl::profiler::kThreadpoolListenerRecord) {
        has_record_event = true;
        const auto& stat = event.GetStat(StatType::kProducerId);
        EXPECT_TRUE(stat.has_value());
        // Only dereference when present; EXPECT_TRUE does not stop execution.
        if (stat.has_value()) record_region_id = stat->IntOrUintValue();
      } else if (event.Name() ==
                 tsl::profiler::kThreadpoolListenerStartRegion) {
        has_start_region_event = true;
        const auto& stat = event.GetStat(StatType::kConsumerId);
        EXPECT_TRUE(stat.has_value());
        if (stat.has_value()) start_region_id = stat->IntOrUintValue();
        region_timespan = event.GetTimespan();
      } else if (event.Name() == tsl::profiler::kThreadpoolListenerStopRegion) {
        has_end_region_event = true;
        // Extend the region from the start-region event's begin to the
        // stop-region event's end.
        region_timespan = Timespan::FromEndPoints(region_timespan.begin_ps(),
                                                  event.GetTimespan().end_ps());
      } else if (event.Name() == "hello") {
        has_traceme_event = true;
        traceme_timespan = event.GetTimespan();
      }
    });
  });

  EXPECT_TRUE(has_record_event);
  EXPECT_TRUE(has_start_region_event);
  EXPECT_TRUE(has_end_region_event);
  EXPECT_TRUE(has_traceme_event);

  // The producer id on the record event links it to the consumer id on the
  // start-region event.
  EXPECT_EQ(record_region_id, start_region_id);

  EXPECT_TRUE(region_timespan.Includes(traceme_timespan));
}

} // namespace
} // namespace profiler
} // namespace xla

0 comments on commit d8e28bb

Please sign in to comment.