diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index b34c1bb532aab1..acddce1a78e1b2 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -304,6 +304,7 @@ cc_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/profiler/backends/cpu:threadpool_listener_state", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index 087a2fc9d34c6b..9c3f436bf65aa2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -468,6 +468,7 @@ cc_library( "//tsl/platform:strcat", "//tsl/platform:stringpiece", "//tsl/platform:types", + "//tsl/profiler/backends/cpu:threadpool_listener_state", ], alwayslink = True, ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/tracing_impl.h b/third_party/xla/third_party/tsl/tsl/platform/default/tracing_impl.h index 8e06e4f60e8ae5..c254e432c8b391 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/tracing_impl.h +++ b/third_party/xla/third_party/tsl/tsl/platform/default/tracing_impl.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_DEFAULT_TRACING_IMPL_H_ #define TENSORFLOW_TSL_PLATFORM_DEFAULT_TRACING_IMPL_H_ +#ifndef IS_MOBILE_PLATFORM +#include "tsl/profiler/backends/cpu/threadpool_listener_state.h" /* Needed only by the non-mobile branch of EventCollector::IsEnabled(). */ +#endif /* IS_MOBILE_PLATFORM */ // Stub implementations of tracing functionality. // Definitions that do nothing for platforms that don't have underlying thread @@ -33,7 +36,13 @@ limitations under the License.
namespace tsl { namespace tracing { -inline bool EventCollector::IsEnabled() { return false; } +inline bool EventCollector::IsEnabled() { /* Non-mobile builds forward to tsl::profiler::threadpool_listener::IsEnabled(); IS_MOBILE_PLATFORM builds keep the previous always-false behavior. NOTE(review): this function is no longer a pure stub on non-mobile builds, so the file's "Stub implementations of tracing functionality" comment no longer describes it accurately. */ +#ifndef IS_MOBILE_PLATFORM + return tsl::profiler::threadpool_listener::IsEnabled(); +#else + return false; +#endif /* IS_MOBILE_PLATFORM */ +} } // namespace tracing } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD index 25c75b1353d8e2..84fbfecfa0d1ee 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/backends/cpu/BUILD @@ -160,7 +160,5 @@ cc_library( name = "threadpool_listener_state", srcs = ["threadpool_listener_state.cc"], hdrs = ["threadpool_listener_state.h"], - visibility = internal_visibility([ - "//tsl/platform:__subpackages__", - ]), + visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc index 881f46e50837ff..886569b8338e6e 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc @@ -14,19 +14,25 @@ limitations under the License.
==============================================================================*/ #include "xla/backends/profiler/cpu/host_tracer.h" +#include #include #include #include #include +#include #include "absl/types/optional.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/utils/tf_xplane_visitor.h" +#include "tsl/profiler/utils/timespan.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_visitor.h" @@ -37,8 +43,11 @@ namespace { using ::tsl::Env; using ::tsl::Thread; using ::tsl::ThreadOptions; +using ::tsl::profiler::StatType; +using ::tsl::profiler::Timespan; using ::tsl::profiler::TraceMe; using ::tsl::profiler::XEventVisitor; +using ::tsl::profiler::XLineVisitor; using ::tsl::profiler::XPlaneVisitor; using ::tsl::profiler::XStatVisitor; @@ -153,6 +162,69 @@ TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) { EXPECT_EQ(e6.DisplayName(), "Iterator::ParallelMap"); } +TEST(HostTracerTest, CollectEventsFromThreadPool) { /* Schedules one task that emits a "hello" TraceMe on a tsl ThreadPool while the host tracer runs, then checks that threadpool-listener record/start/stop events are collected, that the record event's producer id matches the start event's consumer id, and that the "hello" span lies inside the start..stop region. */ + auto tracer = CreateHostTracer({}); + TF_EXPECT_OK(tracer->Start()); + { /* The pool is scoped so its destructor joins the worker thread before tracer->Stop(). TraceMe's destructor runs after counter.DecrementCount() (and the listener's stop-region event is presumably recorded only after the closure returns), so waiting on the counter alone races with Stop() and can drop those events. 
*/ + tsl::thread::ThreadPool thread_pool(/*env=*/Env::Default(), + /*name=*/"HostTracerTest", + /*num_threads=*/1); + tsl::BlockingCounter counter(1); + thread_pool.Schedule([&counter] { + TraceMe traceme("hello"); + counter.DecrementCount(); + }); + counter.Wait(); + } + TF_EXPECT_OK(tracer->Stop()); + tensorflow::profiler::XSpace space; + TF_EXPECT_OK(tracer->CollectData(&space)); + + ASSERT_THAT(space.planes(), testing::SizeIs(1)); /* Fatal: space.planes(0) below must not be read from an empty list. */ + XPlaneVisitor xplane = tsl::profiler::CreateTfXPlaneVisitor(&space.planes(0)); + + bool has_record_event = false; + bool has_start_region_event = false; + bool has_end_region_event = false; + int64_t
record_region_id = 0; + int64_t start_region_id = 0; + + Timespan region_timespan; + Timespan traceme_timespan; + + xplane.ForEachLine([&](const XLineVisitor& line) { + line.ForEachEvent([&](const XEventVisitor& event) { + if (event.Name() == tsl::profiler::kThreadpoolListenerRecord) { + has_record_event = true; + const auto& stat = event.GetStat(StatType::kProducerId); + ASSERT_TRUE(stat.has_value()); /* Fatal (returns from this callback) so stat-> is never dereferenced on an empty optional. */ + record_region_id = stat->IntOrUintValue(); + } else if (event.Name() == + tsl::profiler::kThreadpoolListenerStartRegion) { + has_start_region_event = true; + const auto& stat = event.GetStat(StatType::kConsumerId); + ASSERT_TRUE(stat.has_value()); /* Fatal (returns from this callback) so stat-> is never dereferenced on an empty optional. */ + start_region_id = stat->IntOrUintValue(); + region_timespan = event.GetTimespan(); + } else if (event.Name() == tsl::profiler::kThreadpoolListenerStopRegion) { /* NOTE(review): assumes the start-region event was visited first, since its begin_ps() is reused here. */ + has_end_region_event = true; + region_timespan = Timespan::FromEndPoints(region_timespan.begin_ps(), + event.GetTimespan().end_ps()); + } else if (event.Name() == "hello") { + traceme_timespan = event.GetTimespan(); + } + }); + }); + + EXPECT_TRUE(has_record_event); + EXPECT_TRUE(has_start_region_event); + EXPECT_TRUE(has_end_region_event); + + EXPECT_EQ(record_region_id, start_region_id); + + EXPECT_TRUE(region_timespan.Includes(traceme_timespan)); +} + } // namespace } // namespace profiler } // namespace xla