[go: nahoru, domu]

Skip to content

Commit

Permalink
Add top_k argument to post-process of conditional/deformable-DETR (huggingface#22787)
Browse files Browse the repository at this point in the history

* update min k_value of conditional detr post-processing

* feat: add top_k arg to post processing of deformable and conditional detr

* refactor: revert changes to deprecated methods

* refactor: move prob reshape to improve code clarity and reduce repetition
  • Loading branch information
CreatlV authored and gojiteji committed Jun 5, 2023
1 parent 8735cd9 commit e0a87e6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ def post_process(self, outputs, target_sizes):

# Copied from transformers.models.deformable_detr.image_processing_deformable_detr.DeformableDetrImageProcessor.post_process_object_detection with DeformableDetr->ConditionalDetr
def post_process_object_detection(
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
):
"""
Converts the raw output of [`ConditionalDetrForObjectDetection`] into final bounding boxes in (top_left_x,
Expand All @@ -1342,6 +1342,8 @@ def post_process_object_detection(
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
top_k (`int`, *optional*, defaults to 100):
Keep only top k bounding boxes before filtering by thresholding.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
Expand All @@ -1356,7 +1358,9 @@ def post_process_object_detection(
)

prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
prob = prob.view(out_logits.shape[0], -1)
k_value = min(top_k, prob.size(1))
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
scores = topk_values
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
labels = topk_indexes % out_logits.shape[2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ def post_process(self, outputs, target_sizes):
return results

def post_process_object_detection(
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None, top_k: int = 100
):
"""
Converts the raw output of [`DeformableDetrForObjectDetection`] into final bounding boxes in (top_left_x,
Expand All @@ -1339,6 +1339,8 @@ def post_process_object_detection(
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
top_k (`int`, *optional*, defaults to 100):
Keep only top k bounding boxes before filtering by thresholding.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
Expand All @@ -1353,7 +1355,9 @@ def post_process_object_detection(
)

prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
prob = prob.view(out_logits.shape[0], -1)
k_value = min(top_k, prob.size(1))
topk_values, topk_indexes = torch.topk(prob, k_value, dim=1)
scores = topk_values
topk_boxes = torch.div(topk_indexes, out_logits.shape[2], rounding_mode="floor")
labels = topk_indexes % out_logits.shape[2]
Expand Down

0 comments on commit e0a87e6

Please sign in to comment.