[go: nahoru, domu]

Skip to content

Commit

Permalink
add scatter_add jet rule, fixes google#5365
Browse files Browse the repository at this point in the history
could use a better test though...
  • Loading branch information
mattjj authored and NeilGirdhar committed Apr 1, 2021
1 parent 920589e commit 355e206
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
1 change: 1 addition & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4555,6 +4555,7 @@ def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted, unique_indices):
operand, scatter_indices, updates = primals
g_operand, g_scatter_indices, g_updates = tangents
del g_scatter_indices # ignored
val_out = scatter_add_p.bind(
operand, scatter_indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
Expand Down
16 changes: 13 additions & 3 deletions jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
return fun.call_wrapped(*tracers)


class ZeroTerm:
  """Marker type for a term that is known to be zero.

  Only the single shared instance ``zero_term`` is meant to be used.
  """


zero_term = ZeroTerm()
# Registered as a pytree with no children and no aux data, so tree utilities
# treat it as an empty node and unflattening always returns the shared
# ``zero_term`` singleton (identity is preserved through tree_map etc.).
register_pytree_node(ZeroTerm, lambda z: ((), None), lambda _, xs: zero_term)


class ZeroSeries:
  """Marker type for a whole series that is known to be zero.

  Only the single shared instance ``zero_series`` is meant to be used.
  """


zero_series = ZeroSeries()
# Same leafless-pytree registration as ZeroTerm: unflattening always yields
# the shared ``zero_series`` singleton.
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)

Expand Down Expand Up @@ -549,7 +549,6 @@ def _select_taylor_rule(primal_in, series_in, **params):
return primal_out, series_out
jet_rules[lax.select_p] = _select_taylor_rule


def _lax_max_taylor_rule(primal_in, series_in):
x, y = primal_in

Expand Down Expand Up @@ -589,3 +588,14 @@ def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
del jvp_jaxpr_thunk
return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule

def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
                      dimension_numbers, indices_are_sorted, unique_indices):
  """Jet (Taylor-mode) rule for ``lax.scatter_add_p``.

  With the indices held fixed, scatter-add is linear in the operand and the
  updates together. So the k-th output term is simply the scatter-add of the
  k-th operand term with the k-th updates term, placed at the *primal*
  indices. Whatever series is supplied for the indices is ignored.

  Args:
    primals_in: ``(operand, scatter_indices, updates)`` primal values.
    series_in: per-input lists of series terms, in the same order as
      ``primals_in``.
    update_jaxpr, update_consts, dimension_numbers, indices_are_sorted,
    unique_indices: scatter_add primitive parameters, forwarded unchanged.

  Returns:
    ``(primal_out, series_out)``, where ``series_out`` is a list with one
    scatter-added term per order.
  """
  scatter_params = dict(update_jaxpr=update_jaxpr,
                        update_consts=update_consts,
                        dimension_numbers=dimension_numbers,
                        indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices)
  operand, indices, updates = primals_in
  out = lax.scatter_add_p.bind(operand, indices, updates, **scatter_params)

  out_terms = []
  # Walk the three series in lockstep; the middle (indices) term is unused.
  for operand_term, _unused_index_term, updates_term in zip(*series_in):
    out_terms.append(lax.scatter_add_p.bind(
        operand_term, indices, updates_term, **scatter_params))
  return out, out_terms
jet_rules[lax.scatter_add_p] = _scatter_add_rule
20 changes: 20 additions & 0 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,26 @@ def g(eps):
return jax.grad(f)(x, eps)
jet(g, (1.,), ([1.],)) # doesn't crash

def test_scatter_add(self):
  # Basic regression test for https://github.com/google/jax/issues/5365,
  # which is fixed by adding the scatter_add jet rule. Differentiating
  # through x[0] / x[1] indexing presumably transposes into scatter_add,
  # which jet must then propagate.
  def objective(v):
    return (v[0] ** 5 + v[1] ** 5).sum()

  def h(eps):
    base = jnp.array([1., 1.])
    direction = eps * base

    def along_direction(t):
      return objective(base + t * direction)

    return jax.grad(jax.jacfwd(along_direction))(0.)

  self.check_jet(h, (0.,), ([1., 2., 3.],))


if __name__ == '__main__':
  # Run this module's tests via absltest, using jtu.JaxTestLoader as the loader.
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 355e206

Please sign in to comment.