Fix dynamic_stitch gradient implementation
Addresses tensorflow#7397

Also expanded unit tests to cover these cases.
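Background for the fix: dynamic_stitch resolves duplicate indices in favor of the later-listed partition, so only the occurrence that actually reaches the output should receive gradient; the earlier, overwritten slots must get zeros. The previous gradient simply gathered the upstream gradient for every partition, sending the full gradient to every duplicate. A minimal sketch of the forward semantics this relies on (illustrative values, TF 1.x-era API):

    import tensorflow as tf

    # Index 1 appears in both partitions; the later partition wins,
    # so the merged result takes 21. rather than 20. at position 1.
    indices = [tf.constant([0, 1]), tf.constant([1, 2])]
    data = [tf.constant([10., 20.]), tf.constant([21., 30.])]
    stitched = tf.dynamic_stitch(indices, data)
    with tf.Session() as sess:
      print(sess.run(stitched))  # [10. 21. 30.]
    # Hence d(output)/d(data[0][1]) should be 0, which is what the
    # new testDuplicates below asserts.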
drasmuss committed Feb 13, 2017
1 parent 714f9b7 commit e17aaa4
Showing 2 changed files with 95 additions and 6 deletions.
65 changes: 61 additions & 4 deletions tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -30,7 +30,7 @@
 class DynamicStitchTest(test.TestCase):
 
   def testScalar(self):
-    with self.test_session():
+    with self.test_session() as sess:
       indices = [constant_op.constant(0), constant_op.constant(1)]
       data = [constant_op.constant(40), constant_op.constant(60)]
       for step in -1, 1:
@@ -42,8 +42,16 @@ def testScalar(self):
         # length.
         self.assertEqual([None], stitched_t.get_shape().as_list())
 
+        # Test gradients
+        stitched_grad = 7 * stitched_val
+        grads = gradients_impl.gradients(stitched_t, indices + data,
+                                         stitched_grad)
+        self.assertEqual(grads[:2], [None] * 2)  # Indices have no gradients
+        for datum, grad in zip(data, sess.run(grads[2:])):
+          self.assertAllEqual(7 * datum.eval(), grad)
+
   def testSimpleOneDimensional(self):
-    with self.test_session():
+    with self.test_session() as sess:
       indices = [
           constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
       ]
@@ -59,8 +67,16 @@ def testSimpleOneDimensional(self):
       # length.
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
+      # Test gradients
+      stitched_grad = 7 * stitched_val
+      grads = gradients_impl.gradients(stitched_t, indices + data,
+                                       stitched_grad)
+      self.assertEqual(grads[:2], [None] * 2)  # Indices have no gradients
+      for datum, grad in zip(data, sess.run(grads[2:])):
+        self.assertAllEqual(7 * datum.eval(), grad)
+
   def testOneListOneDimensional(self):
-    with self.test_session():
+    with self.test_session() as sess:
       indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
       data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
       stitched_t = data_flow_ops.dynamic_stitch(indices, data)
@@ -71,8 +87,15 @@ def testOneListOneDimensional(self):
       # length.
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
+      # Test gradients
+      stitched_grad = 7 * stitched_val
+      grads = gradients_impl.gradients(stitched_t, indices + data,
+                                       stitched_grad)
+      self.assertEqual(grads[0], None)  # Indices have no gradients
+      self.assertAllEqual(7 * data[0].eval(), sess.run(grads[1]))
+
   def testSimpleTwoDimensional(self):
-    with self.test_session():
+    with self.test_session() as sess:
       indices = [
           constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
           constant_op.constant([2, 3, 5])
@@ -91,6 +114,14 @@ def testSimpleTwoDimensional(self):
       # some unknown number of rows.
       self.assertEqual([None, 2], stitched_t.get_shape().as_list())
 
+      # Test gradients
+      stitched_grad = 7 * stitched_val
+      grads = gradients_impl.gradients(stitched_t, indices + data,
+                                       stitched_grad)
+      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
+      for datum, grad in zip(data, sess.run(grads[3:])):
+        self.assertAllEqual(7 * datum.eval(), grad)
+
   def testHigherRank(self):
     with self.test_session() as sess:
       indices = [
@@ -115,6 +146,32 @@ def testHigherRank(self):
       for datum, grad in zip(data, sess.run(grads[3:])):
         self.assertAllEqual(7 * datum.eval(), grad)
 
+  def testDuplicates(self):
+    with self.test_session() as sess:
+      indices = [
+          constant_op.constant([0, 1]), constant_op.constant([1, 2, 3]),
+          constant_op.constant([3, 4])
+      ]
+      data = [
+          constant_op.constant([1, 2]), constant_op.constant([12, 13, 14]),
+          constant_op.constant([24, 25])
+      ]
+      stitched_t = data_flow_ops.dynamic_stitch(indices, data)
+      stitched_val = stitched_t.eval()
+      correct = [1, 12, 13, 24, 25]
+      self.assertAllEqual(correct, stitched_val)
+      self.assertEqual([None], stitched_t.get_shape().as_list())
+
+      # Test gradients
+      stitched_grad = stitched_val
+      grads = gradients_impl.gradients(stitched_t, indices + data,
+                                       stitched_grad)
+      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
+      grad_vals = sess.run(grads[3:])
+      self.assertAllEqual(grad_vals[0], [1, 0])
+      self.assertAllEqual(grad_vals[1], [12, 13, 0])
+      self.assertAllEqual(grad_vals[2], [24, 25])
+
   def testErrorIndicesMultiDimensional(self):
     indices = [
         constant_op.constant([0, 4, 7]), constant_op.constant([[1, 6, 2, 3, 5]])
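The gradient expectations in testDuplicates follow from the overwrite semantics: index 1 is claimed by the second partition and index 3 by the third, so the overwritten slots in earlier partitions receive zero gradient. A quick NumPy rendering of the same reverse-pass bookkeeping the new gradient performs (illustrative sketch, not part of the commit):

    import numpy as np

    indices = [np.array([0, 1]), np.array([1, 2, 3]), np.array([3, 4])]
    upstream = np.array([1, 12, 13, 24, 25])  # gradient w.r.t. the stitched output

    claimed = np.zeros(5, dtype=bool)  # output slots owned by later partitions
    grads = []
    for idx in reversed(indices):
      # A slot already claimed by a later partition contributes no gradient here.
      grads.insert(0, np.where(claimed[idx], 0, upstream[idx]))
      claimed[idx] = True

    # grads == [[1, 0], [12, 13, 0], [24, 25]], matching the assertions above.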
36 changes: 34 additions & 2 deletions tensorflow/python/ops/data_flow_grad.py
@@ -20,10 +20,13 @@
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -54,12 +57,41 @@ def _DynamicStitchGrads(op, grad):
   def AsInt32(x):
     return (x if op.inputs[0].dtype == dtypes.int32 else
             math_ops.cast(x, dtypes.int32))
-  inputs = [AsInt32(op.inputs[i]) for i in xrange(num_values)]
+  idxs = [AsInt32(array_ops.reshape(op.inputs[i], (-1,)))
+          for i in xrange(num_values)]
   if isinstance(grad, ops.IndexedSlices):
     output_shape = array_ops.shape(op.outputs[0])
     output_rows = output_shape[0]
     grad = math_ops.unsorted_segment_sum(grad.values, grad.indices, output_rows)
-  values_grad = [array_ops.gather(grad, inp) for inp in inputs]
+
+  values_grad = []
+  later_idxs = idxs[-1]
+  zeros = array_ops.zeros_like(grad)
+  for i in range(num_values - 1, -1, -1):
+    if i == num_values - 1:
+      v_grad = array_ops.gather(grad, idxs[i])
+    else:
+      def is_unique(val):
+        return control_flow_ops.cond(
+            math_ops.reduce_any(math_ops.equal(val, later_idxs)),
+            lambda: constant_op.constant(False),
+            lambda: constant_op.constant(True))
+      unique = functional_ops.map_fn(is_unique, idxs[i], dtypes.bool)
+      diff_indices = array_ops.where(unique)[:, 0]
+      diff_values = array_ops.gather(idxs[i], diff_indices)
+
+      later_idxs = array_ops.concat((diff_values, later_idxs), axis=0)
+
+      n_indices = idxs[i].get_shape()[0]
+      v_grad = data_flow_ops.dynamic_stitch(
+          [math_ops.range(n_indices), math_ops.cast(diff_indices, dtypes.int32)],
+          [zeros[:n_indices], array_ops.gather(grad, diff_values)])
+
+    v_grad = array_ops.reshape(v_grad, op.inputs[i].get_shape().concatenate(
+        v_grad.get_shape()[1:]))
+
+    values_grad = [v_grad] + values_grad
+
   return indices_grad + values_grad
 
 
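Two implementation notes on the new gradient. First, the reverse loop tracks later_idxs, the output slots already claimed by later partitions; map_fn with a per-element cond builds a boolean mask of the surviving indices, where/gather extract them, and dynamic_stitch scatters the gathered gradient rows into a zero tensor at the surviving positions. Second, the IndexedSlices branch kept from the old code densifies a sparse upstream gradient with unsorted_segment_sum, which also sums contributions for repeated rows. A standalone illustration of that densification step (illustrative values, TF 1.x-era API):

    import tensorflow as tf

    # Rows listed in `indices` carry `values`; repeated rows are summed.
    values = tf.constant([[1., 1.], [2., 2.], [4., 4.]])
    indices = tf.constant([0, 2, 2])
    dense = tf.unsorted_segment_sum(values, indices, num_segments=4)
    with tf.Session() as sess:
      print(sess.run(dense))
      # [[1. 1.]
      #  [0. 0.]
      #  [6. 6.]
      #  [0. 0.]]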
