[go: nahoru, domu]

Skip to content

Commit

Permalink
Fix GetOpList and GetPythonWrappers SWIG wrappers for Python 3.
Browse files Browse the repository at this point in the history
- Return the result from GetOpList as uninterpreted bytes object.
- Write an input typemap for GetPythonWrappers to receive a Python 'bytes'
object and convert it to a const char* pointer and length.
Change: 117258253
  • Loading branch information
keveman authored and tensorflower-gardener committed Mar 15, 2016
1 parent 725e968 commit b27022d
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 11 deletions.
6 changes: 1 addition & 5 deletions tensorflow/python/client/tf_session.i
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,7 @@ tensorflow::ImportNumpy();
// is not expected to be NULL-terminated, and TF_Buffer.length does not count
// the terminator.
// NOTE(review): this is a rendered diff hunk, so removed and added lines are
// interleaved without +/- markers. Per the "1 addition & 5 deletions" summary
// for this file, the %#if/%#else/%#endif block (PyString_/PyUnicode_
// FromStringAndSize) is the removed code. The single PyBytes_FromStringAndSize
// call is its replacement, which returns the buffer to Python as an
// uninterpreted bytes object (see the commit message).
%typemap(out) TF_Buffer (TF_GetOpList,TF_GetBuffer) {
%#if PY_MAJOR_VERSION < 3
$result = PyString_FromStringAndSize(
%#else
$result = PyUnicode_FromStringAndSize(
%#endif
$result = PyBytes_FromStringAndSize(
reinterpret_cast<const char*>($1.data), $1.length);
}

Expand Down
4 changes: 1 addition & 3 deletions tensorflow/python/framework/load_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def load_op_library(library_filename):
Pass "library_filename" to a platform-specific mechanism for dynamically
loading a library. The rules for determining the exact location of the
library are platform-specific and are not documented here.
Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList", to be
defined in the library.
Args:
library_filename: Path to the plugin.
Expand Down Expand Up @@ -78,7 +76,7 @@ def load_op_library(library_filename):
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
op_list.ParseFromString(compat.as_bytes(op_list_str))
wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))
wrappers = py_tf.GetPythonWrappers(op_list_str)

# Get a unique name for the module.
module_name = hashlib.md5(wrappers).hexdigest()
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/python/framework/python_op_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ string GetAllPythonOps(const char* hidden, bool require_shapes) {
return GetPythonOps(ops, hidden, require_shapes);
}

string GetPythonWrappers(const char* buf, size_t len) {
string op_list_str(buf, len);
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len) {
string op_list_str(op_wrapper_buf, op_wrapper_len);
OpList ops;
ops.ParseFromString(op_list_str);
return GetPythonOps(ops, "", false);
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/framework/python_op_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ string GetPythonOps(const OpList& ops, const string& hidden_ops,
// Get the Python wrappers for a list of ops in an OpList.
// op_wrapper_buf should point to a buffer containing the binary-encoded OpList
// proto, and op_wrapper_len should be the length of that buffer in bytes.
string GetPythonWrappers(const char* buf, size_t len);
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len);

} // namespace tensorflow

Expand Down
17 changes: 17 additions & 0 deletions tensorflow/python/framework/python_op_gen.i
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen.h"
%}

// Input typemap for the (op_wrapper_buf, op_wrapper_len) argument pair of
// GetPythonWrappers. The Python argument is expected to be a 'bytes' object.
// Its raw contents are passed to C++ as a const char* plus a size_t byte
// count, and the bytes are never interpreted as text.
// This typemap exists because SWIG's default typemap from Python bytes to
// const char* tries to decode the contents from utf-8 to unicode under
// Python >= 3, whereas this binary data must stay uninterpreted.
// If PyBytes_AsStringAndSize fails it reports -1 (a Python exception is
// already pending), and the wrapper bails out via SWIG_fail.
%typemap(in) (const char* op_wrapper_buf, size_t op_wrapper_len) {
  char* raw_data = nullptr;
  Py_ssize_t raw_size = 0;
  const int status = PyBytes_AsStringAndSize($input, &raw_data, &raw_size);
  if (status != 0) {
    SWIG_fail;
  }
  $1 = raw_data;
  $2 = static_cast<size_t>(raw_size);
}


%ignoreall;
%unignore tensorflow::GetPythonWrappers;
%include "tensorflow/python/framework/python_op_gen.h"

0 comments on commit b27022d

Please sign in to comment.