[go: nahoru, domu]

Skip to content

Commit

Permalink
Install Python tensorflow package
Browse files Browse the repository at this point in the history
  • Loading branch information
malmaud committed Aug 8, 2016
1 parent 9647e1f commit 43e19f4
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
deps/bazel-out
deps/downloads
deps/miniconda2
src/scratch.jl
2 changes: 2 additions & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ Requests
BinDeps
ProtoBuf
PyCall
Distributions
Conda
4 changes: 4 additions & 0 deletions deps/build.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Requests
import Conda

base = dirname(@__FILE__)
download_dir = joinpath(base, "downloads")
Expand Down Expand Up @@ -28,3 +29,6 @@ end
mv("libtensorflow.so", "$base/bazel-out/local_linux-fastbuild/bin/tensorflow/libtensorflow.so", remove_destination=true)
mv("libc_api.so", "$base/bazel-out/local_linux-fastbuild/bin/tensorflow/c/libc_api.so", remove_destination=true)
end

# Install the Python-side dependencies: numpy, then the TensorFlow wheel via
# the pip that ships with the Conda-managed Python environment.
# NOTE(review): the wheel is referenced by bare filename, so this presumably
# assumes the build script's working directory contains the downloaded wheel —
# confirm against the download step above.
Conda.install("numpy")
run(`$(Conda.PYTHONDIR)/pip install tensorflow-0.10.0rc0-py2-none-any.whl`)
25 changes: 25 additions & 0 deletions examples/logistic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Example: linear least-squares loss and its gradient with TensorFlow.jl.
# Builds a graph computing the squared error between observed outputs and a
# linear model X*W, evaluates the loss, then asks TensorFlow for dLoss/dW.
#
# Fix: the example previously did not load the TensorFlow package at all, so
# `Session`, `placeholder`, `Variable`, `reduce_sum`, `gradients`, etc. were
# undefined when the script was run standalone.
using TensorFlow
using Distributions

# Synthetic data: 100 samples with 3 features, mapped through a
# ground-truth 3x10 weight matrix.
x = randn(100, 3)
w = randn(3, 10)
y = x*w

sess = Session()

# Graph inputs, fed with concrete arrays at run time.
X = placeholder(Float64)
Y_obs = placeholder(Float64)

# Trainable model weights, randomly initialized.
W = Variable(randn(3, 10))

# Model prediction and sum-of-squared-error loss.
Y = X*W
Loss = reduce_sum((Y-Y_obs)^2)

# Variables must be initialized before the loss can be evaluated.
run(sess, initialize_all_variables())
run(sess, Loss, Dict(X=>x, Y_obs=>y))

# Gradient of the loss with respect to the weights, evaluated on the data.
grad = gradients(Loss, W)
run(sess, grad, Dict(X=>x, Y_obs=>y))
21 changes: 18 additions & 3 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,13 @@ function Node(node_def::tensorflow.NodeDef)
if dtype == tensorflow._DataType.DT_FLOAT
desc["value"] = Tensor(attr.tensor.float_val[1])
elseif dtype == tensorflow._DataType.DT_INT32
desc["value"] = Tensor(attr.tensor.int_val[1])
if length(attr.tensor.int_val) == 0
desc["value"] = Tensor(Int32[])
else
desc["value"] = Tensor(attr.tensor.int_val[1])
end
elseif dtype == tensorflow._DataType.DT_DOUBLE
desc["value"] = Tensor(attr.tensor.double_val[1])
end
elseif attr_name == "keep_dims"
desc["keep_dims"] = attr.b
Expand All @@ -285,7 +291,7 @@ function Node(node_def::tensorflow.NodeDef)
Node(desc)
end

node_name(node::Node) = ccall((:TF_NodeName), Cstring, (Ptr{Void},), node.ptr) |> unsafe_string
# Return the node's name as recorded in the TensorFlow graph, by querying the
# TF C API. Accepts any AbstractNode by first converting it to a concrete Node.
node_name(node::AbstractNode) = ccall((:TF_NodeName), Cstring, (Ptr{Void},), Node(node).ptr) |> unsafe_string

function get_attr_value_proto(node::Node, attr_name)
buf = Buffer()
Expand All @@ -305,7 +311,16 @@ Base.getindex(node::Node, attr_name) = get_attr_value_proto(node, attr_name)

function Base.eltype(node::AbstractNode)
node = Node(node)
dtype = node["dtype"]._type
dtype = nothing
try
dtype = node["dtype"]._type
catch
try
dtype = node["T"]._type
catch
error("eltype called on node with no type information")
end
end
dt = tensorflow._DataType
type_map = Dict(dt.DT_FLOAT=>Float32, dt.DT_INT32=>Int32, dt.DT_DOUBLE=>Float64, dt.DT_INT64=>Int64)
return type_map[dtype]
Expand Down
2 changes: 1 addition & 1 deletion src/nn.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module NN
module nn

import ..TensorFlow: Node, NodeDescription, get_def_graph, capitalize, add_input, Port, get_name

Expand Down
1 change: 1 addition & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ end

# Scalar-times-node: forward to the element-wise broadcast method so that
# literal-coefficient notation like `2x` works on graph nodes.
*(x::Number, n::AbstractNode) = x.*n # For supporting notation like `2x`
# Route integer exponents through the generic (AbstractNode, Any) method via
# `invoke`, bypassing Base's special lowering of literal integer powers.
^(n::AbstractNode, x::Int) = invoke(^, (AbstractNode, Any), n, x)
# Element-wise power on a node is defined as ordinary node power.
# NOTE(review): this forwards to `^`, whose only visible node method takes an
# Int exponent — a non-Int `x` here presumably relies on another `^` method
# elsewhere in this file; confirm, or this may raise a MethodError.
.^(n::AbstractNode, x::Number) = n^x

for (jl_func_name, tf_func_name) in [
(:log, "Log"),
Expand Down
4 changes: 3 additions & 1 deletion src/py.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function extend_graph(graph, other::PyObject)
end


function gradients(y, x)
function gradients(y, x::AbstractArray)
py_graph = make_py_graph(get_def_graph())
to_py_node = node->py_graph[:get_tensor_by_name](string(node_name(node), ":0"))
py_x = [to_py_node(node) for node in x]
Expand All @@ -50,3 +50,5 @@ function gradients(y, x)
extend_graph(get_def_graph(), py_graph_def)
return [get_node_by_name(get_def_graph(), _[:name])|>get for _ in grad_node]
end

# Single-node convenience method: wrap the node in a one-element vector,
# dispatch to the array method above, and return the lone gradient.
function gradients(y, x)
    return gradients(y, [x])[1]
end

0 comments on commit 43e19f4

Please sign in to comment.