[go: nahoru, domu]

Skip to content

Commit

Permalink
[xla:ffi] Add xla::ffi::ExecutionContext to pass user data to FFI han…
Browse files Browse the repository at this point in the history
…dlers at run time

There are two kinds of data that can be added to execution context:
1. Opaque pointers with a deleter, for users that register context types that can be defined in shared libraries and can't depend on XLA internals
2. Typed ExecutionContext::UserData for FFI handlers in the same process

Next step is to plumb execution context all the way up to xla::ExecuteOptions and PjRt

PiperOrigin-RevId: 631993077
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed May 9, 2024
1 parent 28bf747 commit 24a30e3
Show file tree
Hide file tree
Showing 14 changed files with 536 additions and 35 deletions.
30 changes: 30 additions & 0 deletions third_party/xla/xla/ffi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,39 @@ cc_library(
],
)

# Library target for the FFI execution context, built from
# execution_context.cc / execution_context.h. Depends only on Abseil targets.
cc_library(
name = "execution_context",
srcs = ["execution_context.cc"],
hdrs = ["execution_context.h"],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

# Unit test target for the :execution_context library
# (sources in execution_context_test.cc).
xla_cc_test(
name = "execution_context_test",
srcs = ["execution_context_test.cc"],
deps = [
":execution_context",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@local_tsl//tsl/lib/core:status_test_util",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
],
)

cc_library(
name = "ffi",
hdrs = ["ffi.h"],
deps = [
":api",
":execution_context",
"//xla:shape_util",
"//xla:status",
"//xla:types",
Expand All @@ -55,6 +83,7 @@ cc_library(
deps = [
":api",
":call_frame",
":execution_context",
"//xla:status",
"//xla:statusor",
"//xla/ffi/api:c_api",
Expand All @@ -73,6 +102,7 @@ xla_cc_test(
srcs = ["ffi_test.cc"],
deps = [
":call_frame",
":execution_context",
":ffi",
":ffi_api",
"//xla:xla_data_proto_cc",
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/ffi/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ xla_cc_test(
":ffi",
"//xla:xla_data_proto_cc",
"//xla/ffi:call_frame",
"//xla/ffi:execution_context",
"//xla/ffi:ffi_api",
"//xla/stream_executor:device_memory",
"@com_google_absl//absl/log:check",
Expand Down
15 changes: 6 additions & 9 deletions third_party/xla/xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,13 @@ XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api,
std::string_view platform,
XLA_FFI_Handler* handler,
XLA_FFI_Handler_Traits traits) {
// Make copies of string views to guarantee they are null terminated.
std::string name_str(name);
std::string platform_str(platform);

XLA_FFI_Handler_Register_Args args;
args.struct_size = XLA_FFI_Handler_Register_Args_STRUCT_SIZE;
args.priv = nullptr;
args.name = name_str.c_str();
args.platform = platform_str.c_str();
args.name = XLA_FFI_ByteSpan{XLA_FFI_ByteSpan_STRUCT_SIZE, nullptr,
name.data(), name.size()};
args.platform = XLA_FFI_ByteSpan{XLA_FFI_ByteSpan_STRUCT_SIZE, nullptr,
platform.data(), platform.size()};
args.handler = handler;
args.traits = traits;
return api->XLA_FFI_Handler_Register(&args);
Expand Down Expand Up @@ -642,9 +640,8 @@ struct AttrDecoding;
// XLA_FFI_ExecutionContext* ctx);
// }
//
// TODO(ezhulenev): Add an example for decoding opaque data passed together with
// a handler registration (not yet implemented). Today this is only used as
// internal implementation detail of builtin FFI handlers.
// Second template parameter is used to conditionally enable/disable context
// decoding specialization for a given type via SFINAE.
// NOTE(review): the comment above refers to a "second template parameter" used
// for SFINAE, but this declaration takes a single parameter `T` — confirm
// whether the comment or the declaration is stale.
template <typename T>
struct CtxDecoding;

Expand Down
24 changes: 22 additions & 2 deletions third_party/xla/xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ struct XLA_FFI_Handler_Register_Args {
size_t struct_size;
void* priv;

const char* name; // null terminated
const char* platform; // null terminated
XLA_FFI_ByteSpan name;
XLA_FFI_ByteSpan platform;
XLA_FFI_Handler* handler;
XLA_FFI_Handler_Traits traits;
};
Expand All @@ -336,6 +336,25 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Handler_Register_Args, traits);
typedef XLA_FFI_Error* XLA_FFI_Handler_Register(
XLA_FFI_Handler_Register_Args* args);

//===----------------------------------------------------------------------===//
// ExecutionContext
//===----------------------------------------------------------------------===//

// Arguments for `XLA_FFI_ExecutionContext_Get`: identifies the execution
// context and the id of the data to look up, and carries the result back to
// the caller in `data`.
struct XLA_FFI_ExecutionContext_Get_Args {
size_t struct_size;  // XLA_FFI_ExecutionContext_Get_Args_STRUCT_SIZE
void* priv;

XLA_FFI_ExecutionContext* ctx;  // execution context to query
XLA_FFI_ByteSpan id;  // id of the requested data, passed as a byte span
void* data; // out
};

XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_ExecutionContext_Get_Args, data);

// Looks up opaque data in the execution context `args->ctx` by `args->id` and
// returns it via the `args->data` out field. Callers treat a non-null return
// value as a failure.
typedef XLA_FFI_Error* XLA_FFI_ExecutionContext_Get(
XLA_FFI_ExecutionContext_Get_Args* args);

//===----------------------------------------------------------------------===//
// Stream
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -371,6 +390,7 @@ struct XLA_FFI_Api {
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy);
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register);
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get);
_XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ExecutionContext_Get);
};

#undef _XLA_FFI_API_STRUCT_FIELD
Expand Down
6 changes: 6 additions & 0 deletions third_party/xla/xla/ffi/api/c_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ typedef void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(
typedef void* XLA_FFI_INTERNAL_CalledComputation_Get(
XLA_FFI_ExecutionContext* ctx);

// Returns a pointer to the underlying `xla::ffi::ExecutionContext` object,
// which gives access to typed user data attached to the execution context.
typedef void* XLA_FFI_INTERNAL_ExecutionContext_Get(
XLA_FFI_ExecutionContext* ctx);

//===----------------------------------------------------------------------===//
// API access
//===----------------------------------------------------------------------===//
Expand All @@ -77,6 +82,7 @@ struct XLA_FFI_InternalApi {
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(
XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get);
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get);
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionContext_Get);
};

#undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD
Expand Down
79 changes: 63 additions & 16 deletions third_party/xla/xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_FFI_API_FFI_H_
#define XLA_FFI_API_FFI_H_

#include <string_view>
#ifdef XLA_FFI_FFI_H_
#error Two different XLA FFI implementations cannot be included together
#endif // XLA_FFI_FFI_H_
Expand Down Expand Up @@ -489,6 +490,34 @@ struct ResultEncoding<Error> {
}
};

//===----------------------------------------------------------------------===//
// Error helpers
//===----------------------------------------------------------------------===//

namespace internal {

struct ErrorUtil {
static const char* GetErrorMessage(const XLA_FFI_Api* api,
XLA_FFI_Error* error) {
XLA_FFI_Error_GetMessage_Args args;
args.struct_size = XLA_FFI_Error_GetMessage_Args_STRUCT_SIZE;
args.priv = nullptr;
args.error = error;
api->XLA_FFI_Error_GetMessage(&args);
return args.message;
}

static void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) {
XLA_FFI_Error_Destroy_Args args;
args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE;
args.priv = nullptr;
args.error = error;
api->XLA_FFI_Error_Destroy(&args);
}
};

} // namespace internal

//===----------------------------------------------------------------------===//
// PlatformStream
//===----------------------------------------------------------------------===//
Expand All @@ -513,30 +542,48 @@ struct CtxDecoding<PlatformStream<T>> {

if (XLA_FFI_Error* error = api->XLA_FFI_Stream_Get(&args); error) {
diagnostic.Emit("Failed to get platform stream: ")
<< GetErrorMessage(api, error);
DestroyError(api, error);
<< internal::ErrorUtil::GetErrorMessage(api, error);
internal::ErrorUtil::DestroyError(api, error);
return std::nullopt;
}

return reinterpret_cast<T>(args.stream);
}
};

// Returns the message attached to `error`, obtained through
// `api->XLA_FFI_Error_GetMessage`.
// NOTE(review): this free function has the same body as
// `internal::ErrorUtil::GetErrorMessage` — it looks like the pre-refactoring
// copy; confirm it is meant to be removed.
static const char* GetErrorMessage(const XLA_FFI_Api* api,
XLA_FFI_Error* error) {
XLA_FFI_Error_GetMessage_Args args;
args.struct_size = XLA_FFI_Error_GetMessage_Args_STRUCT_SIZE;
args.priv = nullptr;
args.error = error;
api->XLA_FFI_Error_GetMessage(&args);
return args.message;
}
//===----------------------------------------------------------------------===//
// UserData
//===----------------------------------------------------------------------===//

static void DestroyError(const XLA_FFI_Api* api, XLA_FFI_Error* error) {
XLA_FFI_Error_Destroy_Args args;
args.struct_size = XLA_FFI_Error_Destroy_Args_STRUCT_SIZE;
// A type tag for automatic decoding user data passed via the execution context.
//
// `id` is the string id under which the data was added to the execution
// context; `T` is the pointee type: a handler bound with
// `Ctx<UserData<id, T>>()` receives a `T*` argument.
template <const char* id, typename T>
struct UserData {};

// Context decoding for `UserData<id, T>`: looks up the opaque pointer stored in
// the execution context under `id` and passes it to the handler as a `T*`.
template <const char* id, typename T>
struct CtxDecoding<UserData<id, T>> {
  using Type = T*;

  // Fetches the data registered under `id` from `ctx` via
  // `XLA_FFI_ExecutionContext_Get`.
  //
  // Returns the stored pointer cast to `T*` on success. If the API call
  // returns an error, emits its message to `diagnostic`, destroys the error
  // and returns `std::nullopt`.
  static std::optional<Type> Decode(const XLA_FFI_Api* api,
                                    XLA_FFI_ExecutionContext* ctx,
                                    DiagnosticEngine& diagnostic) {
    // The id crosses the C API boundary as a (pointer, size) byte span, so
    // compute its length once from the null-terminated template argument.
    static constexpr std::string_view id_view = {id};

    XLA_FFI_ExecutionContext_Get_Args args;
    args.struct_size = XLA_FFI_ExecutionContext_Get_Args_STRUCT_SIZE;
    args.priv = nullptr;
    args.ctx = ctx;
    args.id = XLA_FFI_ByteSpan{XLA_FFI_ByteSpan_STRUCT_SIZE, nullptr,
                               id_view.data(), id_view.size()};
    args.data = nullptr;

    if (XLA_FFI_Error* err = api->XLA_FFI_ExecutionContext_Get(&args); err) {
      diagnostic.Emit("Failed to get user data from execution context: ")
          << internal::ErrorUtil::GetErrorMessage(api, err);
      // The error is owned by this caller once returned; release it after the
      // message has been streamed into the diagnostic.
      internal::ErrorUtil::DestroyError(api, err);
      return std::nullopt;
    }

    return static_cast<Type>(args.data);
  }
};

Expand Down
33 changes: 33 additions & 0 deletions third_party/xla/xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/ffi_api.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -347,6 +350,36 @@ TEST(FfiTest, PointerAttr) {
TF_ASSERT_OK(status);
}

// Minimal payload type used by the `UserData` test to check that data attached
// to the execution context reaches the handler intact.
struct MyData {
std::string str;
};

// Verifies that an opaque pointer added to the execution context under a
// string id is decoded into a typed `MyData*` handler argument through
// `Ctx<UserData<kId, MyData>>()`.
TEST(FfiTest, UserData) {
  // Used as a non-type template argument below, hence `static constexpr`.
  static constexpr char kId[] = "my_data";

  // The payload lives on the stack, so the registered deleter does nothing.
  MyData payload{"foo"};
  auto noop_deleter = +[](void*) {};

  ExecutionContext execution_context;
  TF_ASSERT_OK(execution_context.Emplace(kId, &payload, noop_deleter));

  CallFrameBuilder builder;
  auto call_frame = builder.Build();

  // Handler body: checks that the decoded pointer refers to the payload.
  auto check_payload = [&](MyData* decoded) {
    EXPECT_EQ(decoded->str, "foo");
    return Error::Success();
  };

  auto handler = Ffi::Bind().Ctx<UserData<kId, MyData>>().To(check_payload);

  CallOptions options;
  options.execution_context = &execution_context;

  auto call_status = Call(*handler, call_frame, options);
  TF_ASSERT_OK(call_status);
}

//===----------------------------------------------------------------------===//
// Performance benchmarks are below.
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 24a30e3

Please sign in to comment.