python3.8/site-packages/jax/experimental/jet.py in process_primitive(self, primitive, tracers, params)
125 if t is zero_term else t for t in series]
126 for x, series in zip(primals_in, series_in)]
--> 127 rule = jet_rules[primitive]
128 primal_out, terms_out = rule(primals_in, series_in, **params)
129 if not primitive.multiple_results:
KeyError: scatter-add
@mattjj
Hi all,
Composing jet with jacfwd and grad sometimes raises the KeyError for scatter-add shown in the traceback above.
Any insights?
Thanks!
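For reference, here is a minimal sketch of the kind of composition described, not the code from the original report. The function `f`, the helper `hessian_jet`, and the inputs `x` and `v` are all hypothetical and chosen only for illustration. It assumes that indexing inside `f` is what makes the reverse-mode gradient contain a scatter-style primitive; whether that primitive has a jet rule may depend on the installed JAX version, so the sketch reports the outcome rather than assuming one.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet


def f(x):
    # Hypothetical scalar function, not the one from the report.
    # Indexing is used because its reverse-mode gradient has to write
    # values back into an array (possibly via a scatter-style primitive).
    return jnp.sin(x[0]) * x[1]


def hessian_jet(x, v):
    # jacfwd(grad(f)) is the Hessian of f (forward-over-reverse).
    hess = jax.jacfwd(jax.grad(f))
    # jet propagates a truncated Taylor series through hess;
    # here the input series has a single first-order term v.
    return jet(hess, (x,), ((v,),))


x = jnp.array([0.5, 2.0])
v = jnp.array([1.0, 0.0])

try:
    primal_out, series_out = hessian_jet(x, v)
    status = "ok"
except Exception as err:  # e.g. a KeyError when a primitive has no jet rule
    primal_out, series_out = None, None
    status = f"failed: {type(err).__name__}: {err}"

print(status)
```

If the call succeeds, `primal_out` is the Hessian of `f` at `x`. If it fails, the printed exception names the primitive that jet could not handle, which is the information needed to locate the missing rule.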