[go: nahoru, domu]

Skip to content

Commit

Permalink
Refactor the unified C API implementation in a private header
Browse files Browse the repository at this point in the history
The idea is to prepare to split the various implementations of the API
by exposing the API through abstract base classes in this private header,
with specific implementations for Graph/Eager/MLIR in separate files.
The C API should then be implementable as a binding over the API
exposed by the abstract base classes, independent of the actual
implementations.

PiperOrigin-RevId: 308387700
Change-Id: I0610d4e9b8067a14d2b31e3c4777d8ed783987c6
  • Loading branch information
joker-eph authored and tensorflower-gardener committed Apr 25, 2020
1 parent b0b883f commit 9355c5e
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 51 deletions.
1 change: 1 addition & 0 deletions tensorflow/c/eager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ tf_cuda_library(
srcs = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
"c_api_unified_experimental_private.h",
],
hdrs = [
"c_api_experimental.h",
Expand Down
70 changes: 19 additions & 51 deletions tensorflow/c/eager/c_api_unified_experimental.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_private.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
Expand All @@ -29,43 +28,12 @@ limitations under the License.
#include "tensorflow/core/platform/strcat.h"

using tensorflow::string;
using tensorflow::internal::dynamic_cast_helper;
using tensorflow::internal::ExecutionContext;
using tensorflow::internal::unwrap;
using tensorflow::internal::wrap;

// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================

typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);

struct TF_ExecutionContext {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
enum ExecutionContextKind { GraphContext, EagerContext };
explicit TF_ExecutionContext(ExecutionContextKind kind) : k(kind) {}
ExecutionContextKind getKind() const { return k; }

virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_Status* s) = 0;
virtual TF_AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
virtual ~TF_ExecutionContext() {}

private:
const ExecutionContextKind k;
};

void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }

template <typename T, typename S>
T* dynamic_cast_helper(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }

class TF_GraphContext;
class TF_EagerContext;
Expand Down Expand Up @@ -104,7 +72,7 @@ struct TF_AbstractOp {
};

TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return c->CreateOperation();
return unwrap(c)->CreateOperation();
}

void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
Expand Down Expand Up @@ -207,9 +175,9 @@ struct TF_AbstractFunction {
~TF_AbstractFunction() { TF_DeleteFunction(func); }
};

class TF_EagerContext : public TF_ExecutionContext {
class TF_EagerContext : public ExecutionContext {
public:
TF_EagerContext() : TF_ExecutionContext(kKind) {}
TF_EagerContext() : ExecutionContext(kKind) {}

void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
Expand Down Expand Up @@ -268,7 +236,7 @@ class TF_EagerContext : public TF_ExecutionContext {

~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }

static constexpr ExecutionContextKind kKind = EagerContext;
static constexpr ExecutionContextKind kKind = kEagerContext;

private:
friend TFE_Context* TF_ExecutionContextGetTFEContext(
Expand All @@ -282,10 +250,10 @@ TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}

class TF_GraphContext : public TF_ExecutionContext {
class TF_GraphContext : public ExecutionContext {
public:
TF_GraphContext()
: TF_ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}

TF_AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
Expand Down Expand Up @@ -363,7 +331,7 @@ class TF_GraphContext : public TF_ExecutionContext {

~TF_GraphContext() override {}

static constexpr ExecutionContextKind kKind = GraphContext;
static constexpr ExecutionContextKind kKind = kGraphContext;

private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
Expand Down Expand Up @@ -410,9 +378,9 @@ TF_ExecutionContext* TF_NewExecutionContext(TF_ExecutionContextOptions* options,
auto* ctx = new TF_EagerContext();
ctx->Build(absl::get<TF_EagerContextOptions*>(options->options)->options,
s);
return ctx;
return wrap(ctx);
} else {
return new TF_GraphContext();
return wrap(new TF_GraphContext());
}
}

Expand Down Expand Up @@ -445,14 +413,14 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->ExecuteOperation(op, num_inputs, inputs, o, s);
unwrap(ctx)->ExecuteOperation(op, num_inputs, inputs, o, s);
}

TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(fn_body);
auto* graph_ctx = dynamic_cast_helper<const TF_GraphContext>(unwrap(fn_body));
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
Expand All @@ -469,7 +437,7 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
ctx->RegisterFunction(func, s);
unwrap(ctx)->RegisterFunction(func, s);
}

// Temporary APIs till we figure out how to create scalar valued Eager
Expand All @@ -496,5 +464,5 @@ TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
}

TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
return dynamic_cast_helper<TF_EagerContext>(ctx)->eager_ctx_;
return dynamic_cast_helper<TF_EagerContext>(unwrap(ctx))->eager_ctx_;
}
71 changes: 71 additions & 0 deletions tensorflow/c/eager/c_api_unified_experimental_private.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_

#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/core/platform/casts.h"

namespace tensorflow {
namespace internal {

// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================

// Abstract interface implemented by each execution backend (currently graph
// tracing and eager) behind the opaque C handle TF_ExecutionContext. C API
// entry points convert a handle with unwrap() and dispatch through the virtual
// methods declared here.
struct ExecutionContext {
  // Needed to implement our own version of RTTI since dynamic_cast is not
  // supported in mobile builds. Each concrete subclass passes its kind to the
  // constructor and exposes it as a `static constexpr kKind` member, which is
  // what dynamic_cast_helper<T>() compares against getKind().
  enum ExecutionContextKind { kGraphContext, kEagerContext };
  explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}

  // Polymorphic base: copying through the base would slice, and subclasses
  // hold backend resources released in their destructors, so a copy could
  // release them twice. Suppress copying (C++ Core Guidelines C.67).
  ExecutionContext(const ExecutionContext&) = delete;
  ExecutionContext& operator=(const ExecutionContext&) = delete;

  // Returns the kind this context was constructed with.
  ExecutionContextKind getKind() const { return k; }

  // Executes `op` on the `num_inputs` tensors in `inputs` within this context.
  // Results are presumably placed in `o`; failures are reported through `s`.
  virtual void ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                                TF_AbstractTensor* const* inputs,
                                TF_OutputList* o, TF_Status* s) = 0;
  // Creates a new operation bound to this context. NOTE(review): the caller
  // appears to own the result (TF_DeleteAbstractOp deletes it) — confirm.
  virtual TF_AbstractOp* CreateOperation() = 0;
  // Registers `func` with this context; failures are reported through `s`.
  virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
  // Virtual so subclasses can be deleted through an ExecutionContext*.
  virtual ~ExecutionContext() = default;

 private:
  // Backend kind used by the hand-rolled RTTI above; fixed at construction.
  const ExecutionContextKind k;
};

// Conversions between the opaque C handle TF_ExecutionContext and the C++
// ExecutionContext behind it. A handle is just an ExecutionContext* that has
// been reinterpreted, so unwrap() is only valid on handles produced by wrap().
//
// Declared `inline` (external linkage) rather than `static inline`. As a
// result, every translation unit that includes this header shares a single
// definition instead of emitting its own internal-linkage copy.
inline ExecutionContext* unwrap(TF_ExecutionContext* ctx) {
  return reinterpret_cast<ExecutionContext*>(ctx);
}
inline const ExecutionContext* unwrap(const TF_ExecutionContext* ctx) {
  return reinterpret_cast<const ExecutionContext*>(ctx);
}
inline TF_ExecutionContext* wrap(ExecutionContext* ctx) {
  return reinterpret_cast<TF_ExecutionContext*>(ctx);
}
inline const TF_ExecutionContext* wrap(const ExecutionContext* ctx) {
  return reinterpret_cast<const TF_ExecutionContext*>(ctx);
}

// Checked downcast for the ExecutionContext hierarchy that does not rely on
// C++ RTTI (unavailable in mobile builds). It mirrors dynamic_cast: it returns
// `source` as a T* when `source->getKind()` equals `T::kKind`. Otherwise it
// returns nullptr, including when `source` itself is null.
//
// T: target type, possibly const-qualified (e.g. `const TF_GraphContext`); it
//    must declare a `static constexpr ExecutionContextKind kKind`.
// S: pointer to the base, e.g. `ExecutionContext*` or `const ExecutionContext*`.
template <typename T, typename S>
T* dynamic_cast_helper(S source) {
  // A null pointer has no kind to inspect. Report "not a T" rather than
  // dereferencing it, so callers' nullptr checks also cover null handles.
  if (source == nullptr || source->getKind() != T::kKind) {
    return nullptr;
  }
  return tensorflow::down_cast<T*>(source);
}

} // namespace internal
} // namespace tensorflow

#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_PRIVATE_H_

0 comments on commit 9355c5e

Please sign in to comment.