[go: up, home]

Skip to content

Commit

Permalink
Allow MaybeSavedModelDirectory to read via custom env.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 617310034
  • Loading branch information
martinwicke authored and tensorflower-gardener committed Jul 3, 2024
1 parent 279583c commit 4aea788
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 102 deletions.
2 changes: 1 addition & 1 deletion ci/official/envs/linux_x86_cuda
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_cuda
TFCI_BUILD_PIP_PACKAGE_ARGS="--repo_env=WHEEL_NAME=tensorflow"
TFCI_DOCKER_ARGS="--gpus all"
TFCI_LIB_SUFFIX="-gpu-linux-x86_64"
TFCI_WHL_SIZE_LIMIT=580M
TFCI_WHL_SIZE_LIMIT=600M
8 changes: 4 additions & 4 deletions tensorflow/cc/saved_model/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,16 +535,16 @@ Status LoadSavedModel(const SessionOptions& session_options,
return absl::OkStatus();
}

bool MaybeSavedModelDirectory(const string& export_dir) {
bool MaybeSavedModelDirectory(const string& export_dir, tsl::Env* env) {
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
const string saved_model_cpb_path =
io::JoinPath(export_dir, kSavedModelFilenameCpb);
const string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
return Env::Default()->FileExists(saved_model_pb_path).ok() ||
Env::Default()->FileExists(saved_model_cpb_path).ok() ||
Env::Default()->FileExists(saved_model_pbtxt_path).ok();
return env->FileExists(saved_model_pb_path).ok() ||
env->FileExists(saved_model_cpb_path).ok() ||
env->FileExists(saved_model_pbtxt_path).ok();
}

} // namespace tensorflow
4 changes: 3 additions & 1 deletion tensorflow/cc/saved_model/loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tsl/platform/env.h"

namespace tensorflow {

Expand Down Expand Up @@ -140,7 +141,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
/// the export directory definitely does not contain a SavedModel. If the method
/// returns `true`, the export directory may contain a SavedModel but provides
/// no guarantee that it can be loaded.
bool MaybeSavedModelDirectory(const std::string& export_dir);
bool MaybeSavedModelDirectory(const std::string& export_dir,
tsl::Env* env = tsl::Env::Default());

} // namespace tensorflow

Expand Down
65 changes: 44 additions & 21 deletions third_party/xla/xla/ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ class AnyBuffer {

ABSL_ATTRIBUTE_ALWAYS_INLINE size_t size_bytes() const {
if (ABSL_PREDICT_TRUE(primitive_util::IsArrayType(element_type()))) {
return absl::c_accumulate(dimensions(),
primitive_util::ByteWidth(element_type()),
std::multiplies<int64_t>());
return primitive_util::ByteWidth(element_type()) * element_count();
}
return 0;
}

ABSL_ATTRIBUTE_ALWAYS_INLINE size_t element_count() const {
return absl::c_accumulate(dimensions(), int64_t{1}, std::multiplies<>());
}

se::DeviceMemoryBase device_memory() const {
return se::DeviceMemoryBase(untyped_data(), size_bytes());
}
Expand All @@ -112,11 +114,45 @@ class AnyBuffer {
// The dtype and rank are checked at decoding time. If rank is not specified,
// any rank is accepted.
template <PrimitiveType dtype, size_t rank = internal::kDynamicRank>
struct Buffer {
class Buffer {
public:
using Dimensions = AnyBuffer::Dimensions;

se::DeviceMemory<internal::NativeType<dtype>> data;
Dimensions dimensions;
explicit Buffer(absl::Nonnull<const XLA_FFI_Buffer*> buf) : buf_(buf) {
DCHECK(buf_ != nullptr) << "XLA_FFI_Buffer must be non-null";
}

PrimitiveType element_type() const { return dtype; }

void* untyped_data() const { return buf_->data; }

internal::NativeType<dtype>* typed_data() const {
return reinterpret_cast<internal::NativeType<dtype>*>(untyped_data());
}

Dimensions dimensions() const {
return Dimensions(buf_->dims,
rank == internal::kDynamicRank ? buf_->rank : rank);
}

ABSL_ATTRIBUTE_ALWAYS_INLINE size_t size_bytes() const {
if constexpr (primitive_util::IsArrayType(dtype)) {
return primitive_util::ByteWidth(dtype) * element_count();
}
return 0;
}

ABSL_ATTRIBUTE_ALWAYS_INLINE size_t element_count() const {
return absl::c_accumulate(dimensions(), int64_t{1}, std::multiplies<>());
}

se::DeviceMemory<internal::NativeType<dtype>> device_memory() const {
return se::DeviceMemory<internal::NativeType<dtype>>(
se::DeviceMemoryBase(untyped_data(), size_bytes()));
}

private:
const XLA_FFI_Buffer* buf_;
};

// clang-format off
Expand All @@ -127,7 +163,7 @@ template <PrimitiveType dtype> using BufferR3 = Buffer<dtype, 3>;
template <PrimitiveType dtype> using BufferR4 = Buffer<dtype, 4>;
// clang-format on

using Token = BufferR0<PrimitiveType::TOKEN>;
using Token = BufferR0<PrimitiveType::TOKEN>; // NOLINT

namespace internal {

Expand All @@ -148,20 +184,7 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
}
}

size_t size_bytes = 0;
if constexpr (primitive_util::IsArrayType(dtype)) {
size_bytes = primitive_util::ByteWidth(dtype);
for (int64_t i = 0, r = rank == internal::kDynamicRank ? buf->rank : rank;
i < r; ++i) {
size_bytes *= buf->dims[i];
}
}

Buffer<dtype, rank> buffer;
buffer.data = se::DeviceMemory<NativeType<dtype>>(
se::DeviceMemoryBase(buf->data, size_bytes));
buffer.dimensions = absl::MakeConstSpan(buf->dims, buf->rank);
return buffer;
return Buffer<dtype, rank>(buf);
}

} // namespace internal
Expand Down
30 changes: 15 additions & 15 deletions third_party/xla/xla/ffi/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ TEST(FfiTest, TypedAndRankedBufferArgument) {
auto call_frame = builder.Build();

auto fn = [&](BufferR2<PrimitiveType::F32> buffer) {
EXPECT_EQ(buffer.data.opaque(), storage.data());
EXPECT_EQ(buffer.data.ElementCount(), storage.size());
EXPECT_EQ(buffer.dimensions.size(), 2);
EXPECT_EQ(buffer.untyped_data(), storage.data());
EXPECT_EQ(buffer.element_count(), storage.size());
EXPECT_EQ(buffer.dimensions().size(), 2);
return absl::OkStatus();
};

Expand All @@ -554,8 +554,8 @@ TEST(FfiTest, ComplexBufferArgument) {
auto call_frame = builder.Build();

auto fn = [&](BufferR2<PrimitiveType::C64> buffer) {
EXPECT_EQ(buffer.data.opaque(), storage.data());
EXPECT_EQ(buffer.dimensions.size(), 2);
EXPECT_EQ(buffer.untyped_data(), storage.data());
EXPECT_EQ(buffer.dimensions().size(), 2);
return absl::OkStatus();
};

Expand All @@ -571,8 +571,8 @@ TEST(FfiTest, TokenArgument) {
auto call_frame = builder.Build();

auto fn = [&](Token tok) {
EXPECT_EQ(tok.data.opaque(), nullptr);
EXPECT_EQ(tok.dimensions.size(), 0);
EXPECT_EQ(tok.untyped_data(), nullptr);
EXPECT_EQ(tok.dimensions().size(), 0);
return absl::OkStatus();
};

Expand Down Expand Up @@ -720,21 +720,21 @@ TEST(FfiTest, UpdateBufferArgumentsAndResults) {
// `fn0` expects argument to be `memory0` and result to be `memory1`.
auto fn0 = [&](BufferR2<PrimitiveType::F32> arg,
Result<BufferR2<PrimitiveType::F32>> ret, int32_t n) {
EXPECT_EQ(arg.data.opaque(), storage0.data());
EXPECT_EQ(ret->data.opaque(), storage1.data());
EXPECT_EQ(arg.dimensions, dims);
EXPECT_EQ(ret->dimensions, dims);
EXPECT_EQ(arg.untyped_data(), storage0.data());
EXPECT_EQ(ret->untyped_data(), storage1.data());
EXPECT_EQ(arg.dimensions(), dims);
EXPECT_EQ(ret->dimensions(), dims);
EXPECT_EQ(n, 42);
return absl::OkStatus();
};

// `fn1` expects argument to be `memory1` and result to be `memory0`.
auto fn1 = [&](BufferR2<PrimitiveType::F32> arg,
Result<BufferR2<PrimitiveType::F32>> ret, int32_t n) {
EXPECT_EQ(arg.data.opaque(), storage1.data());
EXPECT_EQ(ret->data.opaque(), storage0.data());
EXPECT_EQ(arg.dimensions, dims);
EXPECT_EQ(ret->dimensions, dims);
EXPECT_EQ(arg.untyped_data(), storage1.data());
EXPECT_EQ(ret->untyped_data(), storage0.data());
EXPECT_EQ(arg.dimensions(), dims);
EXPECT_EQ(ret->dimensions(), dims);
EXPECT_EQ(n, 42);
return absl::OkStatus();
};
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ struct MemsetValue {
static absl::Status MemsetFromValue(
ffi::Result<ffi::BufferR1<PrimitiveType::F32>> result,
MemsetValue* memset_value) {
for (size_t i = 0; i < result->dimensions.at(0); ++i) {
result->data.base()[i] = memset_value->value;
for (size_t i = 0; i < result->element_count(); ++i) {
result->typed_data()[i] = memset_value->value;
}
return absl::OkStatus();
}
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream.h"
#include "xla/test.h"
#include "xla/tests/literal_test_util.h"
Expand Down Expand Up @@ -318,8 +319,8 @@ static absl::Status MemsetFromValue(
uint32_t pattern;
std::memcpy(&pattern, &memset_value->value, sizeof(pattern));

se::DeviceMemoryBase base = result->data;
return stream->Memset32(&base, pattern, result->data.size());
se::DeviceMemoryBase base = result->device_memory();
return stream->Memset32(&base, pattern, base.size());
}

XLA_FFI_DEFINE_HANDLER(kMemsetFromValue, MemsetFromValue,
Expand Down
Loading

0 comments on commit 4aea788

Please sign in to comment.