[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow MaybeSavedModelDirectory to read via custom env. #64761

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/official/envs/linux_x86_cuda
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow"
TFCI_DOCKER_ARGS="--gpus all"
TFCI_LIB_SUFFIX="-gpu-linux-x86_64"
TFCI_WHL_SIZE_LIMIT=580M
TFCI_WHL_SIZE_LIMIT=600M
8 changes: 4 additions & 4 deletions tensorflow/cc/saved_model/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,16 +535,16 @@ Status LoadSavedModel(const SessionOptions& session_options,
return absl::OkStatus();
}

bool MaybeSavedModelDirectory(const string& export_dir) {
bool MaybeSavedModelDirectory(const string& export_dir, tsl::Env* env) {
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
const string saved_model_cpb_path =
io::JoinPath(export_dir, kSavedModelFilenameCpb);
const string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
return Env::Default()->FileExists(saved_model_pb_path).ok() ||
Env::Default()->FileExists(saved_model_cpb_path).ok() ||
Env::Default()->FileExists(saved_model_pbtxt_path).ok();
return env->FileExists(saved_model_pb_path).ok() ||
env->FileExists(saved_model_cpb_path).ok() ||
env->FileExists(saved_model_pbtxt_path).ok();
}

} // namespace tensorflow
4 changes: 3 additions & 1 deletion tensorflow/cc/saved_model/loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tsl/platform/env.h"

namespace tensorflow {

Expand Down Expand Up @@ -140,7 +141,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
/// the export directory definitely does not contain a SavedModel. If the method
/// returns `true`, the export directory may contain a SavedModel but provides
/// no guarantee that it can be loaded.
bool MaybeSavedModelDirectory(const std::string& export_dir,
                              tsl::Env* env = tsl::Env::Default());

} // namespace tensorflow

Expand Down
65 changes: 44 additions & 21 deletions third_party/xla/xla/ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ class AnyBuffer {

// Total payload size in bytes: the element byte width times the element
// count for array element types; 0 for non-array element types (e.g. TOKEN).
ABSL_ATTRIBUTE_ALWAYS_INLINE size_t size_bytes() const {
  if (ABSL_PREDICT_TRUE(primitive_util::IsArrayType(element_type()))) {
    return primitive_util::ByteWidth(element_type()) * element_count();
  }
  return 0;
}

// Number of elements: the product of all dimension extents (1 for rank 0).
ABSL_ATTRIBUTE_ALWAYS_INLINE size_t element_count() const {
  int64_t count = 1;
  for (int64_t dim : dimensions()) count *= dim;
  return count;
}

// Wraps the buffer's base pointer and byte size in a DeviceMemoryBase view.
se::DeviceMemoryBase device_memory() const {
  void* base = untyped_data();
  return se::DeviceMemoryBase(base, size_bytes());
}
Expand All @@ -112,11 +114,45 @@ class AnyBuffer {
// The dtype and rank are checked at decoding time. If rank is not specified,
// any rank is accepted.
template <PrimitiveType dtype, size_t rank = internal::kDynamicRank>
struct Buffer {
class Buffer {
public:
using Dimensions = AnyBuffer::Dimensions;

se::DeviceMemory<internal::NativeType<dtype>> data;
Dimensions dimensions;
explicit Buffer(absl::Nonnull<const XLA_FFI_Buffer*> buf) : buf_(buf) {
DCHECK(buf_ != nullptr) << "XLA_FFI_Buffer must be non-null";
}

PrimitiveType element_type() const { return dtype; }

void* untyped_data() const { return buf_->data; }

internal::NativeType<dtype>* typed_data() const {
return reinterpret_cast<internal::NativeType<dtype>*>(untyped_data());
}

Dimensions dimensions() const {
return Dimensions(buf_->dims,
rank == internal::kDynamicRank ? buf_->rank : rank);
}

ABSL_ATTRIBUTE_ALWAYS_INLINE size_t size_bytes() const {
if constexpr (primitive_util::IsArrayType(dtype)) {
return primitive_util::ByteWidth(dtype) * element_count();
}
return 0;
}

ABSL_ATTRIBUTE_ALWAYS_INLINE size_t element_count() const {
return absl::c_accumulate(dimensions(), int64_t{1}, std::multiplies<>());
}

se::DeviceMemory<internal::NativeType<dtype>> device_memory() const {
return se::DeviceMemory<internal::NativeType<dtype>>(
se::DeviceMemoryBase(untyped_data(), size_bytes()));
}

private:
const XLA_FFI_Buffer* buf_;
};

// clang-format off
Expand All @@ -127,7 +163,7 @@ template <PrimitiveType dtype> using BufferR3 = Buffer<dtype, 3>;
template <PrimitiveType dtype> using BufferR4 = Buffer<dtype, 4>;
// clang-format on

// A token is a rank-0 buffer of PrimitiveType::TOKEN carrying no payload.
using Token = BufferR0<PrimitiveType::TOKEN>;  // NOLINT

namespace internal {

Expand All @@ -148,20 +184,7 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
}
}

size_t size_bytes = 0;
if constexpr (primitive_util::IsArrayType(dtype)) {
size_bytes = primitive_util::ByteWidth(dtype);
for (int64_t i = 0, r = rank == internal::kDynamicRank ? buf->rank : rank;
i < r; ++i) {
size_bytes *= buf->dims[i];
}
}

Buffer<dtype, rank> buffer;
buffer.data = se::DeviceMemory<NativeType<dtype>>(
se::DeviceMemoryBase(buf->data, size_bytes));
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
return Buffer<dtype, rank>(buf);
}

} // namespace internal
Expand Down
30 changes: 15 additions & 15 deletions third_party/xla/xla/ffi/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ TEST(FfiTest, TypedAndRankedBufferArgument) {
auto call_frame = builder.Build();

auto fn = [&](BufferR2<PrimitiveType::F32> buffer) {
EXPECT_EQ(buffer.data.opaque(), storage.data());
EXPECT_EQ(buffer.data.ElementCount(), storage.size());
EXPECT_EQ(buffer.dimensions.size(), 2);
EXPECT_EQ(buffer.untyped_data(), storage.data());
EXPECT_EQ(buffer.element_count(), storage.size());
EXPECT_EQ(buffer.dimensions().size(), 2);
return absl::OkStatus();
};

Expand All @@ -554,8 +554,8 @@ TEST(FfiTest, ComplexBufferArgument) {
auto call_frame = builder.Build();

auto fn = [&](BufferR2<PrimitiveType::C64> buffer) {
EXPECT_EQ(buffer.data.opaque(), storage.data());
EXPECT_EQ(buffer.dimensions.size(), 2);
EXPECT_EQ(buffer.untyped_data(), storage.data());
EXPECT_EQ(buffer.dimensions().size(), 2);
return absl::OkStatus();
};

Expand All @@ -571,8 +571,8 @@ TEST(FfiTest, TokenArgument) {
auto call_frame = builder.Build();

auto fn = [&](Token tok) {
EXPECT_EQ(tok.data.opaque(), nullptr);
EXPECT_EQ(tok.dimensions.size(), 0);
EXPECT_EQ(tok.untyped_data(), nullptr);
EXPECT_EQ(tok.dimensions().size(), 0);
return absl::OkStatus();
};

Expand Down Expand Up @@ -720,21 +720,21 @@ TEST(FfiTest, UpdateBufferArgumentsAndResults) {
// `fn0` expects argument to be `memory0` and result to be `memory1`.
auto fn0 = [&](BufferR2<PrimitiveType::F32> arg,
Result<BufferR2<PrimitiveType::F32>> ret, int32_t n) {
EXPECT_EQ(arg.data.opaque(), storage0.data());
EXPECT_EQ(ret->data.opaque(), storage1.data());
EXPECT_EQ(arg.dimensions, dims);
EXPECT_EQ(ret->dimensions, dims);
EXPECT_EQ(arg.untyped_data(), storage0.data());
EXPECT_EQ(ret->untyped_data(), storage1.data());
EXPECT_EQ(arg.dimensions(), dims);
EXPECT_EQ(ret->dimensions(), dims);
EXPECT_EQ(n, 42);
return absl::OkStatus();
};

// `fn1` expects argument to be `memory1` and result to be `memory0`.
auto fn1 = [&](BufferR2<PrimitiveType::F32> arg,
Result<BufferR2<PrimitiveType::F32>> ret, int32_t n) {
EXPECT_EQ(arg.data.opaque(), storage1.data());
EXPECT_EQ(ret->data.opaque(), storage0.data());
EXPECT_EQ(arg.dimensions, dims);
EXPECT_EQ(ret->dimensions, dims);
EXPECT_EQ(arg.untyped_data(), storage1.data());
EXPECT_EQ(ret->untyped_data(), storage0.data());
EXPECT_EQ(arg.dimensions(), dims);
EXPECT_EQ(ret->dimensions(), dims);
EXPECT_EQ(n, 42);
return absl::OkStatus();
};
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ struct MemsetValue {
// FFI handler that fills every element of the f32 result buffer with
// `memset_value->value`.
static absl::Status MemsetFromValue(
    ffi::Result<ffi::BufferR1<PrimitiveType::F32>> result,
    MemsetValue* memset_value) {
  // element_count()/typed_data() are the Buffer accessor API (the buffer no
  // longer exposes raw `data`/`dimensions` members).
  for (size_t i = 0; i < result->element_count(); ++i) {
    result->typed_data()[i] = memset_value->value;
  }
  return absl::OkStatus();
}
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/test.h"
#include "xla/tests/literal_test_util.h"
Expand Down Expand Up @@ -318,8 +319,8 @@ static absl::Status MemsetFromValue(
uint32_t pattern;
std::memcpy(&pattern, &memset_value->value, sizeof(pattern));

se::DeviceMemoryBase base = result->data;
return stream->Memset32(&base, pattern, result->data.size());
se::DeviceMemoryBase base = result->device_memory();
return stream->Memset32(&base, pattern, base.size());
}

XLA_FFI_DEFINE_HANDLER(kMemsetFromValue, MemsetFromValue,
Expand Down
Loading