Chronos: Support torch.jit acceleration AutoFormerForecaster.predict #6452

Open · wants to merge 13 commits into main_backup

Conversation

@liangs6212 (Contributor) commented Nov 4, 2022

Description

  1. After applying torch.jit, `predict` runs about 25% faster.
  2. Changed the parameter order of autoformer.xxxx_step.
  3. Split AutoCorrelation.forward and compiled the split function with torch.jit.script.
  4. Used torch.jit.optimized_execution to speed up model inference.

1. Why the change?

  1. Speed up model inference.
  2. Sync the parameter order with TSDataset, because torch.jit freezes the model.
  3. AutoCorrelation.forward contains `if` branches, so torch.jit.script is used to ensure the model is compiled correctly without changing its outputs.
  4. Some optimizations need to be disabled, because they may slow down prediction.
    Related link: https://discuss.pytorch.org/t/speed-of-custom-rnn-is-super-slow/63209/3
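Item 3 above can be illustrated with a small standalone example (not the PR's actual code): torch.jit.trace records only the branch taken for the example input, while torch.jit.script keeps data-dependent `if` branches in the compiled graph. The function below is purely illustrative.

```python
import torch

# Minimal sketch of why data-dependent control flow needs torch.jit.script:
# trace would bake in a single branch, script preserves the `if`.
def topk_or_identity(x: torch.Tensor) -> torch.Tensor:
    if x.size(-1) > 4:
        values, _ = torch.topk(x, k=4, dim=-1)
        return values
    return x

scripted = torch.jit.script(topk_or_identity)

assert scripted(torch.randn(2, 3)).shape == (2, 3)  # branch not taken
assert scripted(torch.randn(2, 8)).shape == (2, 4)  # branch taken

# Disabling the profiling executor's extra optimizations can avoid slow
# first/second calls for some models (see the linked forum thread).
with torch.jit.optimized_execution(False):
    _ = scripted(torch.randn(2, 8))
```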

2. User API changes

yhat = autoformer.predict(data, use_jit=False) # use self.internal
# Inference speed(nyc_taxi): ~800ms

yhat = autoformer.predict(data, use_jit=True) # use torch.jit
# Inference speed(nyc_taxi): ~600ms

# Tip: without `torch.jit.optimized_execution`, the first and second jit_fp32 calls each take about 2s.

4. How to test?

  • Unit test

@TheaperDeng (Collaborator) left a comment

others LGTM

@TheaperDeng (Collaborator) commented

We could rely on the forecaster's hyperparameters to decide the dummy input for jit's trace. Please generate the dummy input ourselves rather than using the first input of the data.
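The suggestion above could be sketched as follows; the helper name and hyperparameter names (`past_seq_len`, `input_feature_num`) are hypothetical, not the actual Chronos API:

```python
import torch

# Hypothetical helper: the dummy input's shape is fully determined by the
# forecaster's hyperparameters, so tracing never needs to peek at user data.
def build_dummy_input(past_seq_len: int, input_feature_num: int,
                      batch_size: int = 1) -> torch.Tensor:
    return torch.randn(batch_size, past_seq_len, input_feature_num)

dummy = build_dummy_input(past_seq_len=48, input_feature_num=7)
assert dummy.shape == (1, 48, 7)
# traced = torch.jit.trace(model, dummy)  # `model` would be the trained net
```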

@TheaperDeng (Collaborator) commented on the following snippet:

forecaster.fit(train_data, epochs=2)
pred = forecaster.predict(test_data)
jit_pred = forecaster.predict_with_jit(test_data)
assert pred[0].shape == jit_pred[0].shape
Could we assert almost equal here? If not, why?
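For reference, an almost-equal check could look like this (illustrative arrays, assuming predictions come back as NumPy arrays; tolerances are examples, not the project's values):

```python
import numpy as np

# Stand-ins for eager and JIT predictions; real outputs may differ by
# tiny float errors because JIT can fuse or reorder kernels.
pred = np.array([1.0000001, 2.0], dtype=np.float32)
jit_pred = np.array([1.0, 2.0], dtype=np.float32)

# Raises AssertionError if any element differs beyond the tolerances.
np.testing.assert_allclose(pred, jit_pred, rtol=1e-5, atol=1e-6)
```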
