Add top_k argument to post-process of conditional/deformable-DETR #22787

CreatlV
Contributor
@CreatlV CreatlV commented Apr 15, 2023

What does this PR do?

The current object-detection post-processing for Deformable DETR and Conditional DETR assumes that the number of classes * the number of object queries is at least 100. This mirrors the original code in the Deformable DETR repository, but it limits the flexibility of training on datasets with fewer classes or object queries. This PR updates the object-detection post-processing so that it does not break when n_classes * n_object_queries < 100.

This PR adds a top_k argument to the post-processing functions of Conditional/Deformable DETR. Its default is the previously hard-coded value.
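As a rough illustration of the idea, here is a minimal pure-Python sketch of a top-k selection that never asks for more candidates than exist. The function name `top_k_scores` and the flat list input are assumptions made for this example; the actual post-processing operates on tensors and is not reproduced here.

```python
import heapq


def top_k_scores(flat_scores, top_k=100):
    """Return the highest scores, never requesting more than are available.

    Hypothetical helper for illustration only, not the library's code.
    """
    # Clamp k so that small models (few queries or few classes) do not break.
    k = min(top_k, len(flat_scores))
    return heapq.nlargest(k, flat_scores)


# 5 queries * 3 classes = 15 candidates, fewer than the default of 100.
scores = [i / 15 for i in range(15)]
best = top_k_scores(scores)
```

With the clamp in place, the default `top_k=100` simply returns all 15 candidates instead of failing, while an explicit smaller `top_k` still limits the output.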

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts, as you added these models, do you think this approach is a reasonable addition?

@HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev commented Apr 15, 2023

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts
Collaborator

@CreatlV Thanks for opening this PR!

Having the processor be compatible with the model is definitely something we want, and updating the code to make it less brittle is a great initiative. At the moment, with min, if num_queries * num_classes < 100, then the model will return all of the boxes. I think we could adapt this further to make k scale according to the model. Specifically, we could add an argument (e.g. k) to the method, which defaults to num_queries or its current default. We can keep the min upper bound to keep it safe.

@NielsRogge For the number of boxes returned, the default k value for this model is 100 (rather than 300). Was there a reason for setting it to this value? (I'm guessing consistency with other detr models?)

@NielsRogge
Contributor
NielsRogge commented Apr 24, 2023

@amyeroberts it was done to reflect the original code, as linked in his message. The probabilities get reshaped to (batch_size, num_queries*num_labels) and then the top 100 values (highest scoring queries) are taken for each example in the batch. However, since Deformable DETR uses 300 queries by default, this will always be > 100. But when you train the model from scratch with a custom number of queries, this would indeed raise an error.
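The flatten-then-rank step described above can be sketched in plain Python for a single image. This is only an illustration under stated assumptions: `flatten_and_rank` is a hypothetical name, nested lists stand in for tensors, and the integer division and modulo used to recover the query and label are how a flat index into a (num_queries, num_labels) grid is usually decoded.

```python
def flatten_and_rank(probs, top_k=100):
    """Rank all (query, label) pairs of one image by probability.

    probs: list of num_queries rows, each holding num_labels probabilities.
    Hypothetical helper for illustration only, not the library's code.
    """
    num_labels = len(probs[0])
    # Flatten to a single list of length num_queries * num_labels.
    flat = [p for row in probs for p in row]
    # Clamp k so a custom, small number of queries does not cause a failure.
    k = min(top_k, len(flat))
    order = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:k]
    # Decode each flat index back into (query index, label index).
    return [(flat[i], i // num_labels, i % num_labels) for i in order]


probs = [
    [0.10, 0.80],  # query 0
    [0.95, 0.05],  # query 1
    [0.30, 0.20],  # query 2
]
ranked = flatten_and_rank(probs, top_k=2)
```

Each returned tuple is (score, query index, label index), so the same query can appear more than once if several of its labels score highly.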

Making this more general makes sense. Note that we typically also filter predictions by a score threshold: we first reduce the 300 queries to the top 100 recognized objects, and then apply a threshold such as 0.9 to keep only the predictions with a score higher than 0.9. The threshold and the top_k value can both be seen as post-processing hyperparameters. However, I'm not sure top_k is general enough, as it seems DETR-specific.
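To make the interplay of the two hyperparameters concrete, here is a small pure-Python sketch of the two-stage filtering: keep the top_k candidates first, then drop anything at or below the threshold. The name `filter_detections` and the (score, label) tuples are assumptions for this example only.

```python
def filter_detections(candidates, top_k=100, threshold=0.9):
    """Keep the top_k highest-scoring candidates, then apply a score threshold.

    candidates: list of (score, label) pairs for one image.
    Hypothetical helper for illustration only, not the library's code.
    """
    k = min(top_k, len(candidates))
    top = sorted(candidates, key=lambda c: c[0], reverse=True)[:k]
    # Only predictions scoring strictly higher than the threshold survive.
    return [c for c in top if c[0] > threshold]


candidates = [(0.97, "cat"), (0.40, "dog"), (0.92, "cat"), (0.91, "bird")]
kept = filter_detections(candidates, top_k=2, threshold=0.9)
```

Here `top_k=2` removes the 0.91 detection even though it passes the threshold, which shows that the two settings are independent knobs rather than redundant ones.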

@CreatlV
Contributor Author
CreatlV commented Apr 27, 2023

I added top_k as an argument to the post-processing functions of Conditional/Deformable DETR that used a hard-coded value, with the default unchanged from before. The top_k value for post_process of Conditional DETR is 300, compared to 100 for the other functions. Is this intentional, @NielsRogge?

@CreatlV CreatlV marked this pull request as ready for review April 27, 2023 10:18
@CreatlV CreatlV changed the title from "[WIP] Fix post-process hardcoding in conditional/deformable-DETR" to "Fix post-process hardcoding in conditional/deformable-DETR" Apr 27, 2023
@CreatlV CreatlV changed the title from "Fix post-process hardcoding in conditional/deformable-DETR" to "Add top_k argument to post-process of conditional/deformable-DETR" Apr 27, 2023
@CreatlV CreatlV requested a review from amyeroberts May 7, 2023 19:29
Collaborator
@amyeroberts amyeroberts left a comment

Thanks for adding this and iterating!

Just one small comment on reshaping probs. After that we're good to merge :)

@CreatlV CreatlV force-pushed the fix/deformable-detr-post-processing-hard-coded-value branch from 87420b1 to b9ef81b May 9, 2023 14:55
@amyeroberts
Collaborator

Thanks again for adding this improvement and iterating! 🎉

@amyeroberts amyeroberts merged commit b92abfa into huggingface:main May 11, 2023
sheonhan pushed a commit to sheonhan/transformers that referenced this pull request May 15, 2023
…uggingface#22787)

* update min k_value of conditional detr post-processing

* feat: add top_k arg to post processing of deformable and conditional detr

* refactor: revert changes to deprecated methods

* refactor: move prob reshape to improve code clarity and reduce repetition
gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023