Chronos: Support torch.jit acceleration for AutoFormerForecaster.predict #6452
base: main_backup
Conversation
Force-pushed 83b6833 → d231132
others LGTM
python/chronos/src/bigdl/chronos/forecaster/autoformer_forecaster.py
We could rely on the forecaster's hyperparameters to decide the dummy input for jit's trace. Please generate the dummy input ourselves rather than using the first input of the data.
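The suggestion above can be sketched in plain PyTorch. The hyperparameter names `past_seq_len` and `input_feature_num` and the stand-in model are assumptions for illustration, not the forecaster's actual attributes:

```python
import torch

# Build the dummy input for jit tracing from known hyperparameters
# instead of taking the first batch of user data.
past_seq_len, input_feature_num = 24, 2   # hypothetical hyperparameters
dummy_input = torch.rand(1, past_seq_len, input_feature_num)

# Stand-in model; in the PR this would be the forecaster's internal model.
model = torch.nn.Linear(input_feature_num, 1)
traced = torch.jit.trace(model, dummy_input)

# The traced module accepts other batch sizes of the same feature shape.
out = traced(torch.rand(5, past_seq_len, input_feature_num))
```

Generating the dummy input this way keeps tracing independent of whatever data the user happens to pass first.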
Force-pushed bbd6d83 → f0b399d
Force-pushed 092316d → 5337d81
Force-pushed e188e17 → c2749c9
Force-pushed 9a5b156 → df43bfb
others LGTM
forecaster.fit(train_data, epochs=2)
pred = forecaster.predict(test_data)
jit_pred = forecaster.predict_with_jit(test_data)
assert pred[0].shape == jit_pred[0].shape
Could we assert almost-equal here? If not, why?
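The check the reviewer asks about could look like this sketch; the arrays are stand-ins, since in the PR's test they would come from `predict` and `predict_with_jit` on the same data:

```python
import numpy as np

# Stand-in outputs; in the real test these come from forecaster.predict
# and forecaster.predict_with_jit.
pred = np.array([1.0, 2.0, 3.0])
jit_pred = pred + 1e-7  # tiny numeric drift from the jit path

# Shape equality alone misses numeric drift; assert_allclose checks
# element-wise closeness within a tolerance and raises on mismatch.
np.testing.assert_allclose(pred, jit_pred, rtol=1e-5)
```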
Description
- By supporting torch.jit, the performance of "predict" is improved by about 25%.
- …autoformer.xxxx_step.
- …AutoCorrelation.forward, and trace the function using torch.jit.script.
- torch.jit.optimized_execution can speed up model inference.
1. Why the change?
- …TSDataset, because torch.jit will freeze the model.
- AutoCorrelation.forward contains if branches, so torch.jit.script is used to ensure that the model is traced correctly and the performance of the model is not changed.
- Related link: https://discuss.pytorch.org/t/speed-of-custom-rnn-is-super-slow/63209/3
2. User API changes
4. How to test?