[tsl:concurrency] Update AsyncValueRef documentation and add more implicit constructors

Explicitly state in the AsyncValueRef<T> documentation that it is an async version of absl::StatusOr<T>, and that this is how users should think about it.

Also add a test to verify that absl::StatusOr<absl::Status> and AsyncValueRef<absl::Status> have consistent implicit construction behavior.

PiperOrigin-RevId: 642517408
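The "async absl::StatusOr<T>" mental model from the commit message can be illustrated with a small single-threaded sketch. It is not the TSL implementation: ToyStatus, ToyAsyncValue, AndThen, ObserveConcrete, and ObserveError are hypothetical names, and only the state names (IsAvailable, IsConcrete, IsError, get, GetError, emplace, SetError) echo the ones used in the diff. The value starts unavailable, later resolves exactly once to either a payload or an error, and runs any callbacks that were registered before it resolved.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for absl::Status (error message only).
struct ToyStatus {
  std::string message;
};

// Hypothetical single-threaded model of an "async StatusOr<T>": it starts
// unavailable and later resolves exactly once to a T or to an error.
template <typename T>
class ToyAsyncValue {
 public:
  bool IsAvailable() const { return state_.index() != 0; }
  bool IsConcrete() const { return state_.index() == 1; }
  bool IsError() const { return state_.index() == 2; }

  const T& get() const { return std::get<1>(state_); }
  const ToyStatus& GetError() const { return std::get<2>(state_); }

  // Resolves with a payload. Index-based emplace keeps T == ToyStatus legal.
  void emplace(T value) {
    assert(!IsAvailable() && "value may be resolved only once");
    state_.template emplace<1>(std::move(value));
    RunWaiters();
  }

  // Resolves with an error.
  void SetError(ToyStatus status) {
    assert(!IsAvailable() && "value may be resolved only once");
    state_.template emplace<2>(std::move(status));
    RunWaiters();
  }

  // Hypothetical: runs `f` now if resolved, otherwise once the value resolves.
  void AndThen(std::function<void()> f) {
    if (IsAvailable()) {
      f();
    } else {
      waiters_.push_back(std::move(f));
    }
  }

 private:
  void RunWaiters() {
    std::vector<std::function<void()>> waiters;
    waiters.swap(waiters_);
    for (auto& waiter : waiters) waiter();
  }

  // Index 0: unavailable, 1: concrete payload, 2: error.
  std::variant<std::monostate, T, ToyStatus> state_;
  std::vector<std::function<void()>> waiters_;
};

// A callback registered before resolution observes the final payload.
int ObserveConcrete() {
  ToyAsyncValue<int> value;
  int seen = 0;
  value.AndThen([&] { seen = value.IsConcrete() ? value.get() : -1; });
  value.emplace(42);
  return seen;
}

// The same kind of callback observes an error instead.
bool ObserveError() {
  ToyAsyncValue<int> value;
  bool saw_error = false;
  value.AndThen([&] { saw_error = value.IsError(); });
  value.SetError(ToyStatus{"Error"});
  return saw_error;
}
```

Like absl::StatusOr<T>, the resolved state is exactly one of "value" or "error"; the only addition is the initial unavailable state.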
ezhulenev authored and tensorflower-gardener committed Jun 12, 2024
1 parent 743450f commit a6fc82d
Showing 5 changed files with 142 additions and 67 deletions.
31 changes: 7 additions & 24 deletions third_party/xla/xla/tsl/concurrency/async_value.h
@@ -23,11 +23,9 @@ limitations under the License.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string_view>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
@@ -494,7 +492,7 @@ void RunWhenReady(absl::Span<RCReference<AsyncValue> const> values,
//===----------------------------------------------------------------------===//

// Traits for customizing AsyncValue behavior for different payload types.
struct AsyncValueTraits {
struct AsyncPayload {
// Under the normal behavior, if an AsyncValue is in kConstructed state (i.e.
// the payload data is constructed), it will destruct the payload data when
// the AsyncValue enters the error state (e.g., on AsyncValue::SetError()).
@@ -504,34 +502,19 @@ struct AsyncValueTraits {
// the payload value if constructed, will be kept valid when the AsyncValue
// goes into the error state.
struct KeepAsyncValuePayloadOnError {};

// If async value payload inherits from this type, we allow implicit
// AsyncValueRef<T> construction from `absl::Status` as an error async value.
struct AllowImplicitStatusConstruction {};
};

namespace internal {

// TODO(ezhulenev): Remove backward-compatible alias after migrating all users.
using KeepAsyncValuePayloadOnError =
AsyncValueTraits::KeepAsyncValuePayloadOnError;
using KeepAsyncValuePayloadOnError = AsyncPayload::KeepAsyncValuePayloadOnError;

// Subclass for storing the concrete payload of the AsyncValue.
//
// Async value itself is a container that either holds `absl::Status` (in error
// state) or a concrete value of type `T` (in concrete state). Async value that
// holds a status is typically a bad idea, and should be expressed as a plain
// async value.
//
// Example:
// - Prefer `AsyncValueRef<Chain>` to `AsyncValueRef<absl::Status>`.
// Instead of a `Chain` it can be any other empty struct to signal that only
// the potential error is important.
//
// - Prefer `AsyncValueRef<T>` to `AsyncValueRef<absl::StatusOr<T>>`.
// Similar to the `absl::StatusOr<T>` async value will be either in error
// state holding an `absl::Status` error, or in concrete state holding a
// value of type `T`.
//
// Subclass for storing the payload of the AsyncValue
// holds an `absl::Status` or `absl::StatusOr<T>` is typically a bad idea, and
// should be expressed as a plain async value of type `T`.
template <typename T>
class ConcreteAsyncValue : public AsyncValue {
public:
@@ -727,7 +710,7 @@ class ConcreteAsyncValue : public AsyncValue {
};

using DataStoreT = std::conditional_t<
std::is_base_of_v<AsyncValueTraits::KeepAsyncValuePayloadOnError, T>,
std::is_base_of_v<AsyncPayload::KeepAsyncValuePayloadOnError, T>,
DataAndError, DataOrError>;
alignas(AsyncValue::kDataOffset) DataStoreT data_store_;

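The DataStoreT selection in ConcreteAsyncValue picks a storage layout at compile time with std::conditional_t and std::is_base_of_v, depending on whether the payload inherits from KeepAsyncValuePayloadOnError. The sketch below shows only that selection mechanism. It assumes simplified std::optional-based storage rather than TSL's hand-rolled storage, and KeepPayloadOnError, DataStore, Plain, Sticky, and PayloadSurvivesError are hypothetical names (DataOrError and DataAndError borrow the diff's names but are simplified stand-ins).

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical tag mirroring AsyncPayload::KeepAsyncValuePayloadOnError.
struct KeepPayloadOnError {};

// Default storage (simplified): entering the error state destroys the payload.
template <typename T>
struct DataOrError {
  std::optional<T> data;
  std::optional<std::string> error;
  void SetError(std::string e) {
    data.reset();
    error = std::move(e);
  }
};

// Opt-in storage (simplified): a constructed payload stays valid on error.
template <typename T>
struct DataAndError {
  std::optional<T> data;
  std::optional<std::string> error;
  void SetError(std::string e) { error = std::move(e); }
};

// Compile-time selection with the same shape as DataStoreT in the diff.
template <typename T>
using DataStore =
    std::conditional_t<std::is_base_of_v<KeepPayloadOnError, T>,
                       DataAndError<T>, DataOrError<T>>;

struct Plain {
  int x = 0;
};
struct Sticky : KeepPayloadOnError {
  int x = 0;
};

static_assert(std::is_same_v<DataStore<Plain>, DataOrError<Plain>>);
static_assert(std::is_same_v<DataStore<Sticky>, DataAndError<Sticky>>);

// Constructs a payload, sets an error, and reports whether the payload survived.
template <typename T>
bool PayloadSurvivesError() {
  DataStore<T> store;
  store.data.emplace();
  store.SetError("failed");
  return store.data.has_value();
}
```

Because the choice is made by inheritance from an empty tag type, a payload opts in without any extra template argument at the use site.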
63 changes: 42 additions & 21 deletions third_party/xla/xla/tsl/concurrency/async_value_ref.h
@@ -63,42 +63,63 @@ AsyncValueRef<T> MakeConstructedAsyncValueRef(Args&&... args);
template <typename T, typename... Args>
AsyncValueRef<T> MakeAvailableAsyncValueRef(Args&&... args);

// RCReference<AsyncValue> wrapper.
// AsyncValueRef<T> is an asynchronous container for a payload of type `T` or an
// error of type `absl::Status`. It is similar to an `absl::StatusOr<T>`, but
// does not require immediate value or error to be constructed. It is a promise
// that at some point in the future it will become concrete and will hold a
// payload of type `T` or an error of type `absl::Status`.
//
// AsyncValueRef<T> is an alias for RCReference<AsyncValue> that carries payload
// type information. The user does not need to pass the payload data type to
// get() or emplace().
// - Prefer `AsyncValueRef<Chain>` to `AsyncValueRef<absl::Status>`.
// Instead of a `Chain` it can be any other empty struct to signal that only
// the potential error is important.
//
// Like RCReference<AsyncValue>, it represents one reference on the underlying
// AsyncValue. When a callee returns an AsyncValueRef to a caller, the callee
// also transfers their ownership of a reference on the underlying AsyncValue.
// - Prefer `AsyncValueRef<T>` to `AsyncValueRef<absl::StatusOr<T>>`.
// Similar to the `absl::StatusOr<T>` async value will be either in error
// state holding an `absl::Status` error, or in concrete state holding a
// value of type `T`.
template <typename T>
class AsyncValueRef {
public:
// AsyncValueRef<T>::value_type
using value_type = T;

AsyncValueRef() = default;
AsyncValueRef(std::nullptr_t) {} // NOLINT

AsyncValueRef(const AsyncValueRef&) = default;
AsyncValueRef& operator=(const AsyncValueRef&) = default;

AsyncValueRef(AsyncValueRef&&) = default;
AsyncValueRef& operator=(AsyncValueRef&&) = default;

explicit AsyncValueRef(RCReference<AsyncValue> value)
: value_(std::move(value)) {}

// Support implicit conversion from AsyncValueRef<Derived> to
// AsyncValueRef<Base>.
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef(AsyncValueRef<Derived>&& u) // NOLINT
: value_(u.ReleaseRCRef()) {}
template <typename Derived,
internal::DerivedFrom<Derived, AsyncValue>* = nullptr>
explicit AsyncValueRef(RCReference<Derived> value)
: AsyncValueRef(RCReference<AsyncValue>(std::move(value))) {}

// Support implicit construction from nullptr to empty async value ref.
AsyncValueRef(std::nullptr_t) {} // NOLINT

// Support implicit construction from immediate value.
AsyncValueRef(T value) // NOLINT
: AsyncValueRef(MakeAvailableAsyncValueRef<T>(std::move(value))) {}

// Support implicit construction from absl::Status.
// Support implicit construction from immediate Status error convertible to
// absl::Status (only if payload type is not absl::Status, otherwise we
// always pass absl::Status to payload constructor for consistency with
// absl::StatusOr<absl::Status>).
template <typename Status,
std::enable_if_t<std::is_same_v<Status, absl::Status>>* = nullptr>
AsyncValueRef(Status status) // NOLINT
: AsyncValueRef(MakeErrorAsyncValueRef(std::move(status))) {
static_assert(
std::is_base_of_v<AsyncValueTraits::AllowImplicitStatusConstruction, T>,
"Payload must explicitly opt-in implicit status construction");
}
std::enable_if_t<std::is_convertible_v<Status, absl::Status> &&
!std::is_same_v<T, absl::Status>>* = nullptr>
AsyncValueRef(Status&& status) // NOLINT
: AsyncValueRef(MakeErrorAsyncValueRef(std::forward<Status>(status))) {}

// Support implicit conversion from an async value of a derived type.
template <typename Derived, internal::DerivedFrom<Derived, T>* = nullptr>
AsyncValueRef(AsyncValueRef<Derived> derived) // NOLINT
: value_(derived.ReleaseRCRef()) {}

// Support implicit construction from RCReference<ErrorAsyncValue>.
AsyncValueRef(RCReference<ErrorAsyncValue> value) // NOLINT
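The new implicit constructors in async_value_ref.h can be modeled outside of TSL to see which overload wins. This sketch assumes a hypothetical ToyStatus in place of absl::Status and a hypothetical, already-resolved ToyRef<T> in place of AsyncValueRef<T>; ReturnsValue, ReturnsError, and ReturnsStatusPayload are illustrative helpers. Only the overload-selection logic follows the diff: a T becomes a concrete value, anything convertible to the status type becomes an error, and when the payload type is the status type itself the constraint disables the error constructor so the status is stored as the payload.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical stand-in for absl::Status.
struct ToyStatus {
  std::string message;
};

// Hypothetical, already-resolved model of AsyncValueRef<T>.
template <typename T>
class ToyRef {
 public:
  // Immediate value: concrete state.
  ToyRef(T value)  // NOLINT
      : state_(std::in_place_index<0>, std::move(value)) {}

  // Anything convertible to ToyStatus: error state, unless the payload type
  // is ToyStatus itself (then the constructor above stores it as a value).
  template <typename Status,
            std::enable_if_t<std::is_convertible_v<Status, ToyStatus> &&
                             !std::is_same_v<T, ToyStatus>>* = nullptr>
  ToyRef(Status&& status)  // NOLINT
      : state_(std::in_place_index<1>,
               static_cast<ToyStatus>(std::forward<Status>(status))) {}

  bool IsConcrete() const { return state_.index() == 0; }
  bool IsError() const { return state_.index() == 1; }
  const T& get() const { return std::get<0>(state_); }
  const ToyStatus& GetError() const { return std::get<1>(state_); }

 private:
  // Index 0: payload, index 1: error. Indices keep T == ToyStatus unambiguous.
  std::variant<T, ToyStatus> state_;
};

ToyRef<int> ReturnsValue() { return 42; }
ToyRef<int> ReturnsError() { return ToyStatus{"Error"}; }
ToyRef<ToyStatus> ReturnsStatusPayload() { return ToyStatus{"Error"}; }
```

The last case corresponds to the AsyncValueRef<absl::Status> test added in this commit, which mirrors absl::StatusOr<absl::Status>: a status returned into a status-typed container is treated as a value, not as an error.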
96 changes: 91 additions & 5 deletions third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include "xla/tsl/concurrency/async_value_ref.h"

#include <any>
#include <atomic>
#include <cstddef>
#include <cstdint>
@@ -42,16 +43,101 @@ class WrappedInt32 {

constexpr int32_t kTestValue = 42;

TEST(AsyncValueRefTest, ImplicitStatusConversion) {
struct Empty : public AsyncValueTraits::AllowImplicitStatusConstruction {};
TEST(AsyncValueRefTest, MakeUnconstructedStatusOrOfAny) {
auto value = MakeUnconstructedAsyncValueRef<absl::StatusOr<std::any>>();
EXPECT_TRUE(value.IsUnavailable());
}

TEST(AsyncValueRefTest, MakeUnconstructedStatusOr) {
auto value = MakeUnconstructedAsyncValueRef<absl::StatusOr<int32_t>>();
EXPECT_TRUE(value.IsUnavailable());
}

TEST(AsyncValueRefTest, MakeConstructedStatusOr) {
auto value = MakeConstructedAsyncValueRef<absl::StatusOr<int32_t>>(42);
EXPECT_TRUE(value.IsUnavailable());
}

TEST(AsyncValueRefTest, MakeAvailableStatusOr) {
auto value = MakeAvailableAsyncValueRef<absl::StatusOr<int32_t>>(42);
EXPECT_TRUE(value.IsAvailable());
EXPECT_EQ(**value, 42);
}

TEST(AsyncValueRefTest, ImplicitValueConversion) {
auto payload = []() -> AsyncValueRef<WrappedInt32> {
return WrappedInt32{42};
}();

EXPECT_TRUE(payload.IsConcrete());
EXPECT_EQ(payload->value(), 42);
}

auto error = []() -> AsyncValueRef<Empty> {
return absl::InternalError("error");
TEST(AsyncValueRefTest, ImplicitStatusConversion) {
auto error = []() -> AsyncValueRef<WrappedInt32> {
return absl::InternalError("Error");
}();

EXPECT_TRUE(error.IsAvailable());
EXPECT_TRUE(error.IsError());
EXPECT_EQ(error.GetError(), absl::InternalError("error"));
EXPECT_EQ(error.GetError(), absl::InternalError("Error"));
}

TEST(AsyncValueRefTest, ImplicitStatusConversionWithStatusPayload) {
auto status = []() -> absl::StatusOr<absl::Status> {
return absl::InternalError("Error");
}();

auto error = []() -> AsyncValueRef<absl::Status> {
return absl::InternalError("Error");
}();

// Check that AsyncValueRef<absl::Status> behavior is consistent with
// absl::StatusOr<absl::Status> for implicit error conversion.

ASSERT_TRUE(status.ok());
ASSERT_EQ(*status, absl::InternalError("Error"));

EXPECT_TRUE(error.IsConcrete());
EXPECT_EQ(error.get(), absl::InternalError("Error"));
}

TEST(AsyncValueRefTest, ImplicitStatusConversionWithStatusOrPayload) {
auto status = []() -> absl::StatusOr<absl::StatusOr<int32_t>> {
return absl::StatusOr<int32_t>(absl::InternalError("Error"));
}();

auto error = []() -> AsyncValueRef<absl::StatusOr<int32_t>> {
return absl::StatusOr<int32_t>(absl::InternalError("Error"));
}();

// Check that AsyncValueRef<absl::StatusOr<T>> behavior is consistent with
// absl::StatusOr<absl::StatusOr<T>> for implicit error conversion.

ASSERT_TRUE(status.ok());
ASSERT_EQ(status->status(), absl::InternalError("Error"));

EXPECT_TRUE(error.IsConcrete());
EXPECT_EQ(error->status(), absl::InternalError("Error"));
}

TEST(AsyncValueRefTest, ImplicitStatusConversionWithStatusOrPayloadAndStatus) {
auto status = []() -> absl::StatusOr<absl::StatusOr<int32_t>> {
return absl::InternalError("Error");
}();

auto error = []() -> AsyncValueRef<absl::StatusOr<int32_t>> {
return absl::InternalError("Error");
}();

// Check that AsyncValueRef<absl::StatusOr<T>> behavior is consistent with
// absl::StatusOr<absl::StatusOr<T>> for implicit error conversion.

ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status(), absl::InternalError("Error"));

EXPECT_TRUE(error.IsError());
EXPECT_EQ(error.GetError(), absl::InternalError("Error"));
}

TEST(AsyncValueRefTest, ValueCheck) {
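The two StatusOr-payload tests above can be reproduced with a self-contained model. This sketch assumes hypothetical ToyStatus, ToyStatusOr<T> (standing in for absl::StatusOr<T>), and ToyRef<T> (standing in for AsyncValueRef<T>) types, plus illustrative helpers PayloadHoldsError and RefHoldsError. Returning a ToyStatusOr<int> that holds an error selects the payload constructor, so the reference is concrete and the error lives inside the payload. Returning a bare ToyStatus selects the constrained error constructor, which accepts it directly; reaching ToyRef(T) would need a second user-defined conversion (ToyStatus to ToyStatusOr<int>), which copy-initialization does not allow.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical stand-in for absl::Status.
struct ToyStatus {
  std::string message;
};

// Hypothetical stand-in for absl::StatusOr<T>: holds a T or a status.
template <typename T>
class ToyStatusOr {
 public:
  ToyStatusOr(T value) : value_(std::move(value)) {}             // NOLINT
  ToyStatusOr(ToyStatus status) : status_(std::move(status)) {}  // NOLINT
  bool ok() const { return value_.has_value(); }
  const ToyStatus& status() const { return status_; }
  const T& value() const { return *value_; }

 private:
  std::optional<T> value_;
  ToyStatus status_;
};

// Hypothetical, already-resolved model of AsyncValueRef<T>.
template <typename T>
class ToyRef {
 public:
  ToyRef(T value)  // NOLINT
      : state_(std::in_place_index<0>, std::move(value)) {}

  template <typename Status,
            std::enable_if_t<std::is_convertible_v<Status, ToyStatus> &&
                             !std::is_same_v<T, ToyStatus>>* = nullptr>
  ToyRef(Status&& status)  // NOLINT
      : state_(std::in_place_index<1>,
               static_cast<ToyStatus>(std::forward<Status>(status))) {}

  bool IsConcrete() const { return state_.index() == 0; }
  bool IsError() const { return state_.index() == 1; }
  const T& get() const { return std::get<0>(state_); }
  const ToyStatus& GetError() const { return std::get<1>(state_); }

 private:
  std::variant<T, ToyStatus> state_;  // 0: payload, 1: error
};

// Error wrapped in the payload: the reference itself is concrete.
ToyRef<ToyStatusOr<int>> PayloadHoldsError() {
  return ToyStatusOr<int>(ToyStatus{"Error"});
}

// Bare status: the reference itself is in the error state.
ToyRef<ToyStatusOr<int>> RefHoldsError() { return ToyStatus{"Error"}; }
```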
7 changes: 2 additions & 5 deletions third_party/xla/xla/tsl/concurrency/chain.h
@@ -16,13 +16,10 @@ limitations under the License.
#ifndef XLA_TSL_CONCURRENCY_CHAIN_H_
#define XLA_TSL_CONCURRENCY_CHAIN_H_

#include "xla/tsl/concurrency/async_value.h"

namespace tsl {

// An empty struct to signal completion of asynchronous events. We explicitly
// enable implicit conversion from absl::Status to asynchronous errors.
class Chain : public AsyncValueTraits::AllowImplicitStatusConstruction {};
// An empty struct to signal completion of asynchronous events.
class Chain {};

} // namespace tsl

12 changes: 0 additions & 12 deletions third_party/xla/xla/util.h
@@ -55,13 +55,6 @@ limitations under the License.
#include "tsl/platform/logging.h"
#include "tsl/platform/ml_dtypes.h"

namespace tsl {
// Forward declare AsyncValueRef to enable implicit conversion from XLA errors
// to error async values.
template <typename T>
class AsyncValueRef;
} // namespace tsl

namespace xla {

// Converts the unsigned integer n into a mixed-radix representation with the
@@ -253,11 +246,6 @@ absl::Status AppendStatus(absl::Status prior, absl::string_view context);
#define XLA_ERROR_WITH_STRFORMAT_AND_BACKTRACE_SUFFIX(error_type) \
/* NOLINTNEXTLINE(google-explicit-constructor) */ \
operator absl::Status() const { return status; } \
/* NOLINTNEXTLINE(google-explicit-constructor) */ \
template <typename T> \
operator tsl::AsyncValueRef<T>() const { \
return status; \
} \
} \
; \
/*Deduction guide to make variadic arguments play nice with default */ \
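Because the new AsyncValueRef constructor is constrained on std::is_convertible_v<Status, absl::Status>, any error type that already converts to absl::Status can initialize an AsyncValueRef<T> directly. That appears to be why the templated operator tsl::AsyncValueRef<T>() on the XLA error helpers, and the forward declaration it needed, could be deleted from util.h. The sketch below assumes hypothetical ToyStatus, ToyRef<T>, and ToyInternalError types and a FailingComputation helper; ToyInternalError only knows how to convert itself to ToyStatus, yet it can still be returned as a ToyRef<int>.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical stand-in for absl::Status.
struct ToyStatus {
  std::string message;
};

// Hypothetical, already-resolved model of AsyncValueRef<T>.
template <typename T>
class ToyRef {
 public:
  ToyRef(T value)  // NOLINT
      : state_(std::in_place_index<0>, std::move(value)) {}

  // Accepts any type convertible to ToyStatus, not just ToyStatus itself.
  template <typename Status,
            std::enable_if_t<std::is_convertible_v<Status, ToyStatus> &&
                             !std::is_same_v<T, ToyStatus>>* = nullptr>
  ToyRef(Status&& status)  // NOLINT
      : state_(std::in_place_index<1>,
               static_cast<ToyStatus>(std::forward<Status>(status))) {}

  bool IsError() const { return state_.index() == 1; }
  const ToyStatus& GetError() const { return std::get<1>(state_); }

 private:
  std::variant<T, ToyStatus> state_;  // 0: payload, 1: error
};

// Hypothetical XLA-style error helper: it converts to ToyStatus only and has
// no conversion operator to ToyRef<T>.
struct ToyInternalError {
  std::string details;
  operator ToyStatus() const {  // NOLINT
    return ToyStatus{"INTERNAL: " + details};
  }
};

// The error helper converts through the constrained constructor.
ToyRef<int> FailingComputation() { return ToyInternalError{"boom"}; }
```

Keeping the conversion logic in the container's constructor means error types do not need to know about AsyncValueRef at all.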
