Enhancement needed: Composition of jet, grad and jacfwd #5365

Closed
vduarte opened this issue Jan 11, 2021 · 2 comments · Fixed by #6257

vduarte commented Jan 11, 2021

@mattjj

Hi all,

Composing jet with jacfwd and grad sometimes produces an error:

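import jax.numpy as jnp
from jax import grad, jacfwd
from jax.experimental.jet import jet


def f(x):
  x0 = x[0]
  x1 = x[1]
  return (x0**5 + x1**5).sum()


def h(ε):
  x = jnp.array([1., 1.])
  μ = ε * x

  # Restrict f to the line through x in direction μ.
  def F(t):
    return f(x + t * μ)

  # Second derivative of F at t = 0 (grad applied over jacfwd).
  return grad(jacfwd(F))(0.)


# Taylor-mode differentiation of h at ε = 0; this call raises the error.
jet(h, (0.,), ([1.],))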

Any insights?
Thanks!

mattjj commented Jan 12, 2021

What's the error?

vduarte commented Jan 12, 2021

Here:

python3.8/site-packages/jax/experimental/jet.py in process_primitive(self, primitive, tracers, params)
125 if t is zero_term else t for t in series]
126 for x, series in zip(primals_in, series_in)]
--> 127 rule = jet_rules[primitive]
128 primal_out, terms_out = rule(primals_in, series_in, **params)
129 if not primitive.multiple_results:
KeyError: scatter-add

Adding

deflinear(lax.scatter_add_p)

to jet.py seems to solve the problem.

Thanks
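
For readers wondering why registering scatter-add as a linear primitive would help: scatter-add is linear in its operand and updates once the indices are held fixed, so each Taylor coefficient of the output can be obtained by applying the same operation to the matching coefficients of the inputs. The sketch below illustrates that idea in plain Python. It is not the code in jet.py; `scatter_add` and `linear_jet_rule` are hypothetical stand-ins written for this example, and holding the indices fixed is an assumption of the sketch.

```python
# Hypothetical, pure-Python stand-ins; none of these names are JAX APIs.

def scatter_add(operand, indices, updates):
    """Return a copy of `operand` with updates[k] added at position indices[k]."""
    out = list(operand)
    for i, u in zip(indices, updates):
        out[i] += u
    return out


def linear_jet_rule(op, indices, primals, series):
    """Push a truncated Taylor series through an op that is linear in its array inputs.

    primals: (operand, updates)
    series:  (operand_terms, updates_terms), one entry per higher-order coefficient
    """
    operand, updates = primals
    operand_terms, updates_terms = series
    primal_out = op(operand, indices, updates)
    # Linearity: output coefficient k is the op applied to input coefficients k.
    terms_out = [
        op(d_operand, indices, d_updates)
        for d_operand, d_updates in zip(operand_terms, updates_terms)
    ]
    return primal_out, terms_out


primal_out, terms_out = linear_jet_rule(
    scatter_add,
    indices=[0, 2, 0],
    primals=([1.0, 1.0, 1.0], [10.0, 20.0, 30.0]),
    series=(
        [[0.5, 0.0, 0.0], [0.0, 0.0, 0.0]],  # operand coefficients
        [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]],  # updates coefficients
    ),
)
print(primal_out)  # [41.0, 1.0, 21.0]
print(terms_out)   # [[4.5, 0.0, 2.0], [0.0, 0.0, 1.0]]
```

Because the rule never mixes coefficients of different orders, it only needs the primitive itself, which is why a one-line registration can be enough for linear primitives.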

mattjj added the enhancement label (New feature or request) on Jan 12, 2021
mattjj added a commit that referenced this issue Mar 29, 2021
could use a better test though...
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this issue Apr 1, 2021
could use a better test though...