[XLA:FFI] Add a BufferShape alias.
This change introduces a `BufferShape` alias for custom call users' convenience. This new alias eliminates the need for verbose code such as `using Shape = decltype(Buffer<PrimitiveType::F32>::dimensions)`.

PiperOrigin-RevId: 632119253
Adam-Banas authored and tensorflower-gardener committed May 9, 2024
1 parent d84d042 commit 5d4d94c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions third_party/xla/xla/ffi/ffi.h
@@ -56,9 +56,11 @@ struct CalledComputation {}; // binds `HloComputation*`
 //===----------------------------------------------------------------------===//
 
 struct BufferBase {
+  using Shape = absl::Span<const int64_t>;
+
   PrimitiveType dtype;
   se::DeviceMemoryBase data;
-  absl::Span<const int64_t> dimensions;
+  Shape dimensions;
 };
 
 namespace internal {
@@ -72,8 +74,10 @@ using NativeType = typename primitive_util::PrimitiveTypeToNative<dtype>::type;
 
 template <PrimitiveType dtype, size_t rank = internal::kDynamicRank>
 struct Buffer {
+  using Shape = BufferBase::Shape;
+
   se::DeviceMemory<internal::NativeType<dtype>> data;
-  absl::Span<const int64_t> dimensions;
+  Shape dimensions;
 };
 
 // clang-format off
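
The effect of the alias on calling code can be illustrated with a minimal, standalone sketch. It does not include the real XLA headers: `Int64Span` is a hypothetical stand-in for `absl::Span<const int64_t>`, the `PrimitiveType` enum is reduced to two values, the `data` members and the `rank` template parameter are omitted, and `NumElements` is a made-up helper that is not part of the commit.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for absl::Span<const int64_t> (assumption, not XLA code).
struct Int64Span {
  const std::int64_t* ptr = nullptr;
  std::size_t len = 0;
  const std::int64_t* begin() const { return ptr; }
  const std::int64_t* end() const { return ptr + len; }
  std::size_t size() const { return len; }
};

// Reduced stand-in for xla::PrimitiveType (assumption).
enum class PrimitiveType { F32, S32 };

// Mirrors the shape-related parts of the structs in the diff.
struct BufferBase {
  using Shape = Int64Span;
  PrimitiveType dtype;
  Shape dimensions;
};

template <PrimitiveType dtype>
struct Buffer {
  using Shape = BufferBase::Shape;
  Shape dimensions;
};

// Without the alias, user code recovers the type from the member declaration.
using ShapeViaDecltype = decltype(Buffer<PrimitiveType::F32>::dimensions);
// With the alias, the type can be named directly.
using ShapeViaAlias = Buffer<PrimitiveType::F32>::Shape;

static_assert(std::is_same<ShapeViaDecltype, ShapeViaAlias>::value,
              "both spellings name the same type");

// Hypothetical helper: product of all dimensions (1 for an empty shape).
inline std::int64_t NumElements(BufferBase::Shape dims) {
  std::int64_t n = 1;
  for (std::int64_t d : dims) {
    assert(d >= 0 && "dimensions are expected to be non-negative");
    n *= d;
  }
  return n;
}
```

Because `Buffer<dtype>::Shape` is defined in terms of `BufferBase::Shape`, a helper written against `BufferBase::Shape` accepts the dimensions of either struct without naming the underlying span type.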
