[go: nahoru, domu]

Skip to content

Commit

Permalink
[xla:cpu] NFC: Declare LogicalIdThunk as an extern template in the header file
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647270647
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Jun 27, 2024
1 parent bd15b83 commit 5e87cc0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 33 deletions.
55 changes: 29 additions & 26 deletions third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ limitations under the License.
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
namespace xla::cpu::internal {

static Thunk::Kind ToThunkKind(LogicalIdKind logical_id_kind) {
switch (logical_id_kind) {
Expand All @@ -45,41 +45,42 @@ static Thunk::Kind ToThunkKind(LogicalIdKind logical_id_kind) {
}
}

template <LogicalIdKind type>
absl::StatusOr<std::unique_ptr<LogicalIdThunk<type>>>
LogicalIdThunk<type>::Create(Info info,
BufferAllocation::Slice logical_id_buffer) {
template <LogicalIdKind logical_id_kind>
absl::StatusOr<std::unique_ptr<LogicalIdThunk<logical_id_kind>>>
LogicalIdThunk<logical_id_kind>::Create(
Info info, BufferAllocation::Slice logical_id_buffer) {
return absl::WrapUnique(
new LogicalIdThunk(std::move(info), logical_id_buffer));
}

template <LogicalIdKind type>
LogicalIdThunk<type>::LogicalIdThunk(Info info,
BufferAllocation::Slice logical_id_buffer)
: Thunk(ToThunkKind(type), info), logical_id_buffer_(logical_id_buffer) {}
template <LogicalIdKind logical_id_kind>
LogicalIdThunk<logical_id_kind>::LogicalIdThunk(
Info info, BufferAllocation::Slice logical_id_buffer)
: Thunk(ToThunkKind(logical_id_kind), info),
logical_id_buffer_(logical_id_buffer) {}

template <LogicalIdKind type>
template <LogicalIdKind logical_id_kind>
static constexpr auto ToString() {
if constexpr (type == LogicalIdKind::kPartitionId) {
if constexpr (logical_id_kind == LogicalIdKind::kPartitionId) {
return "Partition";
} else if constexpr (type == LogicalIdKind::kReplicaId) {
} else if constexpr (logical_id_kind == LogicalIdKind::kReplicaId) {
return "Replica";
}
}

template <LogicalIdKind type>
absl::StatusOr<int32_t> LogicalIdThunk<type>::GetIdForDevice(
template <LogicalIdKind logical_id_kind>
absl::StatusOr<int32_t> LogicalIdThunk<logical_id_kind>::GetIdForDevice(
const DeviceAssignment* device_assignment, GlobalDeviceId device_id) const {
if constexpr (type == LogicalIdKind::kPartitionId) {
if constexpr (logical_id_kind == LogicalIdKind::kPartitionId) {
return device_assignment->PartitionIdForDevice(device_id);
} else if constexpr (type == LogicalIdKind::kReplicaId) {
} else if constexpr (logical_id_kind == LogicalIdKind::kReplicaId) {
return device_assignment->ReplicaIdForDevice(device_id);
}
}

template <LogicalIdKind type>
tsl::AsyncValueRef<typename LogicalIdThunk<type>::ExecuteEvent>
LogicalIdThunk<type>::Execute(const ExecuteParams& params) {
template <LogicalIdKind logical_id_kind>
tsl::AsyncValueRef<typename LogicalIdThunk<logical_id_kind>::ExecuteEvent>
LogicalIdThunk<logical_id_kind>::Execute(const ExecuteParams& params) {
tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); });

TF_ASSIGN_OR_RETURN(
Expand All @@ -90,14 +91,15 @@ LogicalIdThunk<type>::Execute(const ExecuteParams& params) {
<< "Logical id buffer must be able to fit logical id value";

TF_RET_CHECK(params.collective_params)
<< ToString<type>() << " id requires collective params";
<< ToString<logical_id_kind>() << " id requires collective params";

TF_ASSIGN_OR_RETURN(
int32_t logical_id,
GetIdForDevice(params.collective_params->device_assignment,
params.collective_params->global_device_id));

VLOG(3) << absl::StreamFormat("%s id: %d", ToString<type>(), logical_id);
VLOG(3) << absl::StreamFormat("%s id: %d", ToString<logical_id_kind>(),
logical_id);
VLOG(3) << absl::StreamFormat(" logical_id: slice %s (%p)",
logical_id_buffer_.ToString(),
logical_id_data.opaque());
Expand All @@ -106,15 +108,16 @@ LogicalIdThunk<type>::Execute(const ExecuteParams& params) {
return OkExecuteEvent();
}

template <LogicalIdKind type>
using BufferUses = typename LogicalIdThunk<type>::BufferUses;
template <LogicalIdKind logical_id_kind>
using BufferUses = typename LogicalIdThunk<logical_id_kind>::BufferUses;

template <LogicalIdKind type>
BufferUses<type> LogicalIdThunk<type>::buffer_uses() const {
template <LogicalIdKind logical_id_kind>
BufferUses<logical_id_kind> LogicalIdThunk<logical_id_kind>::buffer_uses()
const {
return {BufferUse::Write(logical_id_buffer_)};
}

template class LogicalIdThunk<LogicalIdKind::kReplicaId>;
template class LogicalIdThunk<LogicalIdKind::kPartitionId>;

} // namespace xla::cpu
} // namespace xla::cpu::internal
15 changes: 11 additions & 4 deletions third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ limitations under the License.

namespace xla::cpu {

namespace internal {
enum class LogicalIdKind {
kPartitionId,
kReplicaId,
};

template <LogicalIdKind type>
template <LogicalIdKind logical_id_kind>
class LogicalIdThunk : public Thunk {
public:
static absl::StatusOr<std::unique_ptr<LogicalIdThunk>> Create(
Expand All @@ -53,11 +54,17 @@ class LogicalIdThunk : public Thunk {
BufferAllocation::Slice logical_id_buffer_;
};

class ReplicaIdThunk final : public LogicalIdThunk<LogicalIdKind::kReplicaId> {
};
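// LogicalIdThunk's member functions are defined and explicitly instantiated in
// logical_id_thunk.cc; these extern declarations keep other translation units
// from instantiating the template implicitly.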
extern template class LogicalIdThunk<LogicalIdKind::kReplicaId>;
extern template class LogicalIdThunk<LogicalIdKind::kPartitionId>;

} // namespace internal

class ReplicaIdThunk final
: public internal::LogicalIdThunk<internal::LogicalIdKind::kReplicaId> {};

class PartitionIdThunk final
: public LogicalIdThunk<LogicalIdKind::kPartitionId> {};
: public internal::LogicalIdThunk<internal::LogicalIdKind::kPartitionId> {};

} // namespace xla::cpu

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/cpu/runtime/logical_id_thunk.h"

#include <cstdint>
#include <limits>
#include <string>
#include <vector>

Expand Down Expand Up @@ -51,7 +52,7 @@ absl::StatusOr<DeviceAssignment> CreateDeviceAssignment(
}

TEST(LogicalIdThunkTest, GetReplicaId) {
std::vector<int32_t> dst(1, -1);
std::vector<int32_t> dst(1, std::numeric_limits<int32_t>::min());

std::vector<MaybeOwningDeviceMemory> buffers;
buffers.emplace_back(se::DeviceMemoryBase(dst.data(), sizeof(int32_t)));
Expand Down Expand Up @@ -86,7 +87,7 @@ TEST(LogicalIdThunkTest, GetReplicaId) {
}

TEST(LogicalIdThunkTest, GetPartitionId) {
std::vector<int32_t> dst(2, -1);
std::vector<int32_t> dst(2, std::numeric_limits<int32_t>::min());

std::vector<MaybeOwningDeviceMemory> buffers;
static constexpr auto kDataSize = 2 * sizeof(int32_t);
Expand Down Expand Up @@ -119,7 +120,7 @@ TEST(LogicalIdThunkTest, GetPartitionId) {
tsl::BlockUntilReady(execute_event);
ASSERT_FALSE(execute_event.IsError());

EXPECT_EQ(dst[0], -1);
EXPECT_EQ(dst[0], std::numeric_limits<int32_t>::min());
EXPECT_EQ(dst[1], 0);
}

Expand Down

0 comments on commit 5e87cc0

Please sign in to comment.