It would be helpful to have wallclock comparisons for the benchmarks you posted. I think Danijar's JAX implementation uses jax.lax.scan heavily to make the imagination/rollout loops efficient.
sai-prasanna changed the title from "Wallclock comparison for the benchmark comparisons" to "Wallclock comparison for the benchmarks" on Nov 22, 2023
Thank you for suggesting wallclock comparisons for the benchmarks. You're right about the efficiency of Danijar's JAX implementation, in particular its heavy use of jax.lax.scan to optimize the imagination and rollout loops; that is a significant factor in its performance.
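As a rough illustration of the pattern (a minimal sketch, not code from either repository; the dynamics function, `W`, and the horizon here are hypothetical), jax.lax.scan lets XLA compile the whole rollout loop into a single fused computation instead of dispatching one op per step from Python:

```python
import jax
import jax.numpy as jnp

# Hypothetical latent dynamics: one imagination step maps h -> tanh(h @ W).
def step(h, _):
    h_next = jnp.tanh(h @ W)
    return h_next, h_next  # (new carry, per-step output to stack)

W = 0.5 * jnp.eye(4)   # toy transition matrix
h0 = jnp.ones(4)       # initial latent state

# scan runs 16 steps inside one compiled XLA computation; the stacked
# per-step outputs come back as a single (16, 4) array.
h_final, trajectory = jax.lax.scan(step, h0, xs=None, length=16)
print(trajectory.shape)  # (16, 4)
```

Because the loop body is traced once and compiled, per-step Python and dispatch overhead is amortized across the whole horizon.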
One key reason my PyTorch-based implementation may not match the speed of the original JAX version is the absence of a directly comparable, efficient scan primitive in PyTorch.
Of course, writing a raw CUDA kernel is one option for closing the gap.
Fortunately, an efficient scan function for PyTorch is in development and looks likely to be available in the near future, as discussed in PyTorch GitHub issue 50688.
Once it is released, I plan to integrate it into our implementation. In the meantime, for insights into the current efficiency of the code, you may find this discussion useful.
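For contrast, here is a minimal sketch (hypothetical names, not code from this repository) of the eager-mode fallback a PyTorch rollout currently has to use, where every step pays Python-loop and kernel-dispatch overhead that a compiled scan would amortize:

```python
import torch

def imagine_rollout(cell, h0, horizon):
    # Eager Python loop: each iteration dispatches its kernels separately,
    # so per-step overhead scales linearly with the horizon.
    h, states = h0, []
    for _ in range(horizon):
        h = cell(h)
        states.append(h)
    return torch.stack(states)  # (horizon, batch, dim)

# Toy stand-in for a latent dynamics cell.
cell = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Tanh())
h0 = torch.zeros(2, 8)
traj = imagine_rollout(cell, h0, 16)
print(traj.shape)  # torch.Size([16, 2, 8])
```

This runs correctly, but nothing fuses the 16 steps into one launch the way jax.lax.scan does, which is why a native scan primitive would help.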