[go: up, home]

Skip to content

Commit

Permalink
Added reduction ops
Browse files (browse the repository at this point in the history)
  • Loading branch information
malmaud committed Aug 4, 2016
1 parent 1824496 commit c6365ff
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
deps/bazel-out
deps/downloads
src/scratch.jl
37 changes: 21 additions & 16 deletions src/ops.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base: log, exp, +, -, *, /
import Base: log, exp, +, -, *, /, .*, .+, ./, .-, ^, .^

# Module-wide mutable integer counter, starting at 1. It is wrapped in a `Ref` so the
# binding itself can stay `const` while the stored value changes.
# NOTE(review): presumably read/incremented by `get_name` to build unique node names — confirm.
const name_idx = Ref(1)

Expand Down Expand Up @@ -54,6 +54,8 @@ for (bin_op, jl_func_name, tf_func_name) in [
@eval $bin_op(n1, n2::Node) = $jl_func_name(constant(n1), n2)
end

# Integer powers of a Node are forwarded explicitly to the generic `^(::Node, ::Any)`
# method via `invoke`.
# NOTE(review): presumably this exists to break the method ambiguity between that
# Node method and Base's integer-power fallback `^(x, p::Integer)`. Declaring the
# exponent as `Integer` (instead of only `Int`) covers every integer exponent type,
# e.g. `n^Int32(2)`, not just the platform `Int`; `Int` exponents behave as before.
^(n::Node, x::Integer) = invoke(^, (Node, Any), n, x)

for (jl_func_name, tf_func_name) in [
(:log, "Log"),
(:exp, "Exp"),
Expand All @@ -78,21 +80,24 @@ end

# Reductions

# NOTE(review): this is the pre-commit (removed) version of `reduce_sum`, shown by the
# diff view next to its replacement loop; no closing `end` follows it in this view.
# It builds "Rank" -> "Range" -> "Sum" nodes, but it has two defects that the
# replacement fixes:
#   * the "Rank" node is created without `add_input(desc, n)`, so it has no input;
#   * the final input is wired with `add_inpnut`, a typo for `add_input`.
function reduce_sum(n::Node, name="")
name = get_name(name)
range_start = constant(Int32(0))
range_delta = constant(Int32(1))
desc = NodeDescription(get_def_graph(), "Rank", "$name/rank")
# NOTE(review): missing `add_input(desc, n)` here — the Rank node gets no input.
rank = Node(desc)
desc = NodeDescription(get_def_graph(), "Range", "$name/range")
add_input(desc, range_start)
add_input(desc, rank)
add_input(desc, range_delta)
range = Node(desc)
desc = NodeDescription(get_def_graph(), "Sum", name)
add_input(desc, n)
add_inpnut(desc, range)  # NOTE(review): typo for `add_input`
Node(desc)
"""
Create a node of type `op_type` named `node_name` in the default graph
(`get_def_graph()`), add each of `srcs` as an input in the order given, and
return the resulting `Node`. Shared by the generated `reduce_*` functions.
"""
function _reduction_node(op_type, node_name, srcs...)
    desc = NodeDescription(get_def_graph(), op_type, node_name)
    for src in srcs
        add_input(desc, src)
    end
    Node(desc)
end

# Generate `reduce_sum`, `reduce_prod`, `reduce_min`, `reduce_max`, `reduce_all`,
# `reduce_any` and `reduce_mean`. Each `reduce_<op>(n::Node, name="")`:
#   1. resolves `name` through `get_name`;
#   2. creates Int32 constants 0 and 1;
#   3. adds a "Rank" node "<name>/rank" fed by `n`;
#   4. adds a "Range" node "<name>/range" fed by (0, rank, 1);
#   5. returns the reduction node (op type `capitalize(reduction)`, e.g. "Sum"),
#      named `name` and fed by `n` and that range.
# NOTE(review): with TensorFlow's Range semantics this presumably reduces over all axes.
for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
    func_name = symbol("reduce_", reduction)
    op_type = capitalize(reduction)
    @eval function $func_name(n::Node, name="")
        name = get_name(name)
        axis_start = constant(Int32(0))
        axis_step = constant(Int32(1))
        rank = _reduction_node("Rank", "$name/rank", n)
        axes_range = _reduction_node("Range", "$name/range", axis_start, rank, axis_step)
        _reduction_node($op_type, name, n, axes_range)
    end
end

# Load the definitions in nn.jl (path resolved relative to this source file).
include("nn.jl")
6 changes: 2 additions & 4 deletions src/scratch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ run(sess, z, Dict(x=>1, y=>2))

# NOTE(review): diff excerpt of the scratch script (src/scratch.jl is gitignored in this
# same commit). The diff view interleaves removed and added lines without +/- markers,
# so this fragment is not meant to be read as one sequential script.
sess=Session()

x=placeholder(Float64)
# NOTE(review): reassigns `x`; the Dict feed below presumably targets the placeholder above.
x=constant([3.0, 5.0])

y=NN.relu(x)
run(sess,NN.softmax(x),Dict(x=>[1.0 -3.0;1 5]))
constant(Int32(1))
# Exercises one of the reduction functions generated in ops.jl.
run(sess, reduce_mean(x))

0 comments on commit c6365ff

Please sign in to comment.