Commit ec7a38e

Merge commit for internal changes

Maciek Chociej committed May 26, 2017
2 parents 072355e + 527a8f2
Showing 192 changed files with 6,922 additions and 4,000 deletions.
1 change: 1 addition & 0 deletions tensorflow/BUILD
@@ -392,6 +392,7 @@ filegroup(
"//tensorflow/tensorboard/demo:all_files",
"//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files",
"//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/images:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files",
"//tensorflow/tensorboard/plugins/scalars:all_files",
"//tensorflow/tensorboard/plugins/text:all_files",
2 changes: 0 additions & 2 deletions tensorflow/compiler/jit/xla_cpu_device.cc
@@ -25,8 +25,6 @@ limitations under the License.

namespace tensorflow {

- const char* const DEVICE_XLA_CPU = "XLA_CPU";
-
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
2 changes: 0 additions & 2 deletions tensorflow/compiler/jit/xla_gpu_device.cc
@@ -25,8 +25,6 @@ limitations under the License.

namespace tensorflow {

- const char* const DEVICE_XLA_GPU = "XLA_GPU";
-
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
2 changes: 2 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -31,6 +31,8 @@ namespace tensorflow {

const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT";
const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
+ const char* const DEVICE_XLA_CPU = "XLA_CPU";
+ const char* const DEVICE_XLA_GPU = "XLA_GPU";

// Is platform 'id' supported by XLA?
static bool IsPlatformSupported(perftools::gputools::Platform::Id id) {
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -41,6 +41,9 @@ namespace tensorflow {
extern const char* const DEVICE_CPU_XLA_JIT; // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT"

+ extern const char* const DEVICE_XLA_CPU;
+ extern const char* const DEVICE_XLA_GPU;
+
constexpr std::array<DataType, 2> kIntTypes = {{DT_INT32, DT_INT64}};
constexpr std::array<DataType, 2> kFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
constexpr std::array<DataType, 4> kNumericTypes = {
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/literal_util.h
@@ -1071,6 +1071,7 @@ template <typename NativeT>
ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions,
stride_config.step, init_function);
} else {
+ // For scalars.
data.at(0) = generator({});
}
return Status::OK();
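The new comment marks the rank-0 branch: ShapeUtil::ForEachIndex walks the shape's dimensions, so for a scalar there is nothing to iterate and the generator must be invoked once, explicitly, with an empty multi-index. A minimal sketch of a generator compatible with that generator({}) call; the ArraySlice parameter type is an assumption inferred from how the diff calls it:

    // Hypothetical generator for a rank-0 (scalar) literal: the index list is
    // empty because a scalar has no dimensions to index.
    auto generator = [](tensorflow::gtl::ArraySlice<int64> indexes) -> float {
      CHECK(indexes.empty());
      return 3.14f;
    };
    // Populate then writes the single element: data.at(0) = generator({});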
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/literal_util_test.cc
@@ -807,7 +807,9 @@ TEST_F(LiteralUtilTest, Populate) {
std::vector<int64> layout;
} populate_data[] = {
{{}, {}},
+ {{0}, {0}},
{{16}, {0}},
+ {{2, 0}, {1, 0}},
{{4, 16}, {1, 0}},
{{21, 12}, {0, 1}},
{{6, 11, 17}, {2, 0, 1}},
29 changes: 28 additions & 1 deletion tensorflow/compiler/xla/service/BUILD
@@ -96,10 +96,12 @@ cc_test(
":hlo_evaluator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
@@ -346,6 +348,7 @@ cc_library(
":device_memory_allocator",
":executable",
":execution_tracker",
":gpu_transfer_manager",
":hlo",
":hlo_cost_analysis",
":hlo_execution_profile",
@@ -439,7 +442,7 @@ cc_library(
cc_library(
name = "gpu_plugin",
deps = [
":generic_transfer_manager",
":gpu_transfer_manager",
":service",
"//tensorflow/compiler/xla/service/gpu:gpu_compiler",
"//tensorflow/core:stream_executor_no_cuda",
@@ -971,6 +974,27 @@ cc_library(
alwayslink = True, # Contains per-platform transfer manager registration
)

+ cc_library(
+     name = "gpu_transfer_manager",
+     srcs = ["gpu_transfer_manager.cc"],
+     hdrs = ["gpu_transfer_manager.h"],
+     deps = [
+         ":generic_transfer_manager",
+         ":transfer_manager",
+         "//tensorflow/compiler/xla:literal_util",
+         "//tensorflow/compiler/xla:shape_util",
+         "//tensorflow/compiler/xla:status_macros",
+         "//tensorflow/compiler/xla:statusor",
+         "//tensorflow/compiler/xla:types",
+         "//tensorflow/compiler/xla:util",
+         "//tensorflow/compiler/xla:xla_data_proto",
+         "//tensorflow/compiler/xla/service/gpu:infeed_manager",
+         "//tensorflow/core:lib",
+         "//tensorflow/core:stream_executor_no_cuda",
+     ],
+     alwayslink = True,  # Contains per-platform transfer manager registration
+ )
+
cc_test(
name = "transfer_manager_test",
srcs = ["transfer_manager_test.cc"],
@@ -1423,7 +1447,9 @@ cc_library(
hdrs = ["hlo_constant_folding.h"],
deps = [
":hlo",
":hlo_evaluator",
":hlo_pass",
":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1499,6 +1525,7 @@ cc_library(
":computation_layout",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:lib",
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/compile_only_service.cc
@@ -116,7 +116,7 @@ CompileOnlyService::CompileAheadOfTime(

TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
versioned_handle, &hlo_module_config,
versioned_handle, hlo_module_config,
/*include_unreachable_instructions=*/true));
hlo_modules.push_back(std::move(hlo_module));
}
9 changes: 2 additions & 7 deletions tensorflow/compiler/xla/service/computation_tracker.cc
@@ -169,7 +169,7 @@ void ComputationTracker::ComputeComputationPostOrder(

StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
const VersionedComputationHandle& entry_handle,
- const HloModuleConfig* config,
+ const HloModuleConfig& config,
bool include_unreachable_instructions) const {
tensorflow::mutex_lock lock(computation_mutex_);

@@ -209,12 +209,7 @@ StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(

string module_name =
tensorflow::strings::StrCat(entry_computation->name(), "_module");
- std::unique_ptr<HloModule> module;
- if (config == nullptr) {
-   module = MakeUnique<HloModule>(module_name, entry_handle);
- } else {
-   module = MakeUnique<HloModule>(module_name, entry_handle, *config);
- }
+ auto module = MakeUnique<HloModule>(module_name, entry_handle, config);
for (auto versioned_handle : post_order) {
UserComputation* computation =
ResolveInternal(versioned_handle.handle).ValueOrDie();
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/computation_tracker.h
@@ -74,14 +74,14 @@ class ComputationTracker {
// module will include the entry computation as well as all computations which
// are called directly or indirectly from the entry computation via operations
// like "map". config is the HLO module configuration to use for the
- // constructed module; pass nullptr for "no configuration".
+ // constructed module.
// If include_unreachable_instructions is true, then instructions
// which are not reachable from the root are lowered into HloInstructions
// including unreachable parameters. This ensures the entry HloComputation has
// the same program shape (ProgramShape) as the entry UserComputation.
StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
const VersionedComputationHandle& entry_handle,
- const HloModuleConfig* config,
+ const HloModuleConfig& config,
bool include_unreachable_instructions = true) const;

string ToString() const;
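Since config is now a const reference, callers that previously passed nullptr for "no configuration" must supply a real object. A hedged sketch of the call-site migration; using a default-constructed HloModuleConfig as the stand-in for the old nullptr case is an assumption, not something shown in this commit:

    // Before: tracker.BuildHloModule(handle, /*config=*/nullptr);
    // After: a config object is always required.
    HloModuleConfig config;  // assumed default-constructible
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                        tracker.BuildHloModule(handle, config));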
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -241,7 +241,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) {
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_config()->mutable_entry_computation_layout());
+ module->mutable_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
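Both backends (see the matching gpu_compiler.cc change below) stop reaching through mutable_config() here, which implies HloModule now exposes the entry computation layout directly. The equivalence, as inferred from this diff; the forwarding accessor itself is not shown:

    // Old spelling: via the module's config.
    ComputationLayout* layout =
        module->mutable_config()->mutable_entry_computation_layout();
    // New spelling: directly on HloModule; assumed to return the same object.
    ComputationLayout* layout2 = module->mutable_entry_computation_layout();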
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/cpu_transfer_manager.cc
@@ -96,8 +96,8 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,

} // namespace xla

- static xla::TransferManager* CreateCpuTransferManager() {
-   return new xla::CpuTransferManager();
+ static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() {
+   return xla::MakeUnique<xla::CpuTransferManager>();
}

static bool InitModule() {
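The registered factory now returns std::unique_ptr, so the registry's ownership of created transfer managers is explicit in the type. For context, a sketch of the full registration idiom this file uses; the InitModule() body is collapsed in this view, so the platform id argument is an assumption:

    static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() {
      return xla::MakeUnique<xla::CpuTransferManager>();
    }

    static bool InitModule() {
      // Assumed platform id; runs once at load time via the static below,
      // which is why these targets are built with alwayslink = True.
      xla::TransferManager::RegisterTransferManager(
          se::host::kHostPlatformId, CreateCpuTransferManager);
      return true;
    }
    static bool module_initialized = InitModule();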
7 changes: 2 additions & 5 deletions tensorflow/compiler/xla/service/executable.h
@@ -123,12 +123,9 @@ class Executable {

const HloModule& module() const { return *hlo_module_; }

const HloModuleConfig& module_config() const { return hlo_module_->config(); }
const bool has_module() const { return hlo_module_ != nullptr; }

- // Returns whether this executable has an associated HloModuleConfig.
- bool has_module_config() const {
-   return hlo_module_ != nullptr && hlo_module_->has_config();
- }
+ const HloModuleConfig& module_config() const { return hlo_module_->config(); }

// Returns the versioned computation handle of the computation computed by
// this executable.
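has_module_config() collapses into the simpler has_module(): now that BuildHloModule always receives a config (see computation_tracker.h above), a module implies a config, and the only remaining question is whether a module is attached at all. A hedged usage sketch of the resulting guard; not code from this commit:

    // module_config() dereferences hlo_module_, so callers check has_module()
    // first when the module may be absent.
    if (executable.has_module()) {
      const xla::HloModuleConfig& config = executable.module_config();
      // ... inspect config, e.g. its entry computation layout ...
    }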
11 changes: 0 additions & 11 deletions tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -180,14 +180,3 @@ int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) {
}

} // namespace xla
-
- static xla::TransferManager* CreateGenericTransferManager() {
-   return new xla::GenericTransferManager(se::cuda::kCudaPlatformId);
- }
-
- static bool InitModule() {
-   xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId,
-                                                 CreateGenericTransferManager);
-   return true;
- }
- static bool module_initialized = InitModule();
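The CUDA-platform registration is removed here and, per the service/BUILD changes above, moves into the new gpu_transfer_manager target. A hedged sketch of what the replacement registration in gpu_transfer_manager.cc plausibly looks like, mirroring the CPU variant; the GpuTransferManager class name is inferred from the file names and is otherwise an assumption:

    static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
      return xla::MakeUnique<xla::GpuTransferManager>();  // assumed class name
    }

    static bool InitModule() {
      xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId,
                                                    CreateGpuTransferManager);
      return true;
    }
    static bool module_initialized = InitModule();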
15 changes: 15 additions & 0 deletions tensorflow/compiler/xla/service/gpu/BUILD
@@ -219,6 +219,7 @@ cc_library(
"for_thunk.cc",
"gemm_thunk.cc",
"gpu_executable.cc",
"infeed_thunk.cc",
"kernel_thunk.cc",
"sequential_thunk.cc",
"thunk_schedule.cc",
@@ -231,6 +232,7 @@ cc_library(
"for_thunk.h",
"gemm_thunk.h",
"gpu_executable.h",
"infeed_thunk.h",
"kernel_thunk.h",
"sequential_thunk.h",
"thunk.h",
@@ ... @@
],
deps = [
":buffer_allocations",
":infeed_manager",
":partition_assignment",
":stream_assignment",
"//tensorflow/compiler/xla:array2d",
@@ -450,6 +453,18 @@ cc_library(
alwayslink = True, # Contains compiler registration
)

+ cc_library(
+     name = "infeed_manager",
+     srcs = ["infeed_manager.cc"],
+     hdrs = ["infeed_manager.h"],
+     deps = [
+         "//tensorflow/compiler/xla:types",
+         "//tensorflow/compiler/xla:util",
+         "//tensorflow/core:lib",
+         "//tensorflow/core:stream_executor_no_cuda",
+     ],
+ )
+
cc_library(
name = "layout_assignment",
srcs = ["layout_assignment.cc"],
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -167,7 +167,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<PadInsertion>();
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_config()->mutable_entry_computation_layout());
+ hlo_module->mutable_entry_computation_layout());
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
92 changes: 92 additions & 0 deletions tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+InfeedManager::InfeedManager()
+    : current_buffer_(nullptr),
+      host_to_device_executor_(nullptr) {}
+
+void InfeedManager::Reset() {
+  tensorflow::mutex_lock l(mu_);
+  CHECK(!current_buffer_);
+  for (auto buffer : enqueued_buffer_) {
+    buffer->Done();
+  }
+  enqueued_buffer_.clear();
+}
+
+void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) {
+  tensorflow::mutex_lock l(mu_);
+  bool was_empty = enqueued_buffer_.empty();
+  enqueued_buffer_.push_back(buffer);
+  if (was_empty) {
+    // This has the potential to suffer from the notified thread
+    // immediately trying and failing to acquire mu_, but seems
+    // preferable to the alternative of notifying outside the lock
+    // on every enqueue.
+    cv_.notify_one();
+  }
+}
+
+InfeedBuffer* InfeedManager::BlockingDequeueBuffer() {
+  tensorflow::mutex_lock l(mu_);
+  while (enqueued_buffer_.empty()) {
+    cv_.wait(l);
+  }
+  CHECK(!current_buffer_);
+  current_buffer_ = enqueued_buffer_.front();
+  enqueued_buffer_.pop_front();
+  return current_buffer_;
+}
+
+void InfeedManager::ReleaseCurrentBuffer(se::DeviceMemoryBase* device_memory) {
+  tensorflow::mutex_lock l(mu_);
+  CHECK(current_buffer_);
+  CHECK(device_memory->IsSameAs(*current_buffer_->device_memory()));
+  current_buffer_->Done();
+  current_buffer_ = nullptr;
+}
+
+se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
+  if (host_to_device_executor_ == nullptr) {
+    host_to_device_executor_ = executor;
+    host_to_device_stream_ = MakeUnique<se::Stream>(executor);
+    host_to_device_stream_->Init();
+  }
+
+  if (executor != host_to_device_executor_) {
+    // The requested executor must be the same as the one for which
+    // the stream is cached.
+    return nullptr;
+  }
+
+  return host_to_device_stream_.get();
+}
+
+InfeedManager* GetOrCreateInfeedManager() {
+  static InfeedManager* manager = new InfeedManager;
+  return manager;
+}
+
+}  // namespace gpu
+}  // namespace xla
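The manager is a process-wide singleton mediating between a host-side producer and the device-side consumer: EnqueueBuffer() appends a buffer and wakes a sleeping consumer only on the empty-to-non-empty transition, BlockingDequeueBuffer() waits on the condition variable and claims the front buffer, and ReleaseCurrentBuffer() verifies the claimed buffer before calling Done() on it. A hedged usage sketch; only the manager calls shown here appear in this file, and InfeedBuffer construction lives elsewhere in the commit:

    // Consumer side (e.g. the new infeed thunk; sketch only):
    xla::gpu::InfeedManager* manager = xla::gpu::GetOrCreateInfeedManager();
    xla::gpu::InfeedBuffer* buffer = manager->BlockingDequeueBuffer();  // blocks
    // ... copy out of *buffer->device_memory() on the device stream ...
    manager->ReleaseCurrentBuffer(buffer->device_memory());  // calls Done()

    // Producer side: hands the buffer to the manager until Done() fires.
    // manager->EnqueueBuffer(buffer);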