feature/extend-detections-with-data-field-to-support-custom-payload #700

Merged
Changes from 1 commit

Commits (25):
f3f545d  simplifying `Detections.__eq__` method (SkalskiP, Dec 28, 2023)
f2031d5  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 28, 2023)
ea6439b  `Detections.__eq__` updated and tested (SkalskiP, Dec 28, 2023)
b2d98fe  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 28, 2023)
efed23f  `Detections.merge` updated and tested (SkalskiP, Dec 29, 2023)
27eb823  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 29, 2023)
701b0c0  fix `test_detections_from_xml_obj` (SkalskiP, Dec 29, 2023)
5dfc2e6  Merge remote-tracking branch 'origin/feature/extend-detections-with-d… (SkalskiP, Dec 29, 2023)
14761cb  fix `test_detections_from_xml_obj` (SkalskiP, Dec 29, 2023)
63ccc98  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 29, 2023)
c241237  Merge remote-tracking branch 'origin/feature/extend-detections-with-d… (SkalskiP, Dec 29, 2023)
9be71da  handle one more edge case in `Detections.merge` logic (SkalskiP, Dec 29, 2023)
db245bb  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 29, 2023)
95e2ad6  refactor of `core.py`; moving utils to `utils.py` (SkalskiP, Dec 29, 2023)
955220f  Merge remote-tracking branch 'origin/feature/extend-detections-with-d… (SkalskiP, Dec 29, 2023)
ea21a75  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 29, 2023)
099d360  better `Detections.merge` docs + more `Detections.merge` tests (SkalskiP, Dec 29, 2023)
5984d33  better `Detections.merge` validation (SkalskiP, Dec 29, 2023)
e75e527  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Dec 29, 2023)
ced9f35  initial `Detections.__getitem__` implementation (SkalskiP, Jan 2, 2024)
5169481  fix(pre_commit): 🎨 auto format pre-commit hooks (pre-commit-ci[bot], Jan 2, 2024)
a10b984  initial `Detections.__iter__` implementation (SkalskiP, Jan 2, 2024)
8e40fa8  Merge remote-tracking branch 'origin/feature/extend-detections-with-d… (SkalskiP, Jan 2, 2024)
0f63966  initial `Detections.__setitem__` implementation + `Detections.__getit… (SkalskiP, Jan 2, 2024)
e3d92a0  ready for colab tests (SkalskiP, Jan 2, 2024)
refactor of core.py; moving utils to utils.py
SkalskiP committed Dec 29, 2023
commit 95e2ad6d5225de927fa44d760549108d59735a6d
supervision/detection/core.py (131 changes: 8 additions & 123 deletions)
@@ -1,7 +1,6 @@
 from __future__ import annotations

 from dataclasses import astuple, dataclass, field
-from itertools import chain
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 import numpy as np
@@ -11,127 +10,12 @@
     extract_ultralytics_masks,
     non_max_suppression,
     process_roboflow_result,
-    xywh_to_xyxy,
+    xywh_to_xyxy, validate_xyxy, validate_mask, validate_class_id, validate_confidence,
+    validate_tracker_id, validate_data, is_data_equal, merge_data,
 )
 from supervision.geometry.core import Position


-def _validate_xyxy(xyxy: Any, n: int) -> None:
-    is_valid = isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4)
-    if not is_valid:
-        raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape")
-
-
-def _validate_mask(mask: Any, n: int) -> None:
-    is_valid = mask is None or (
-        isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask.shape[0] == n
-    )
-    if not is_valid:
-        raise ValueError("mask must be 3d np.ndarray with (n, H, W) shape")
-
-
-def validate_inference_callback(callback) -> None:
-    tmp_img = np.zeros((256, 256, 3), dtype=np.uint8)
-    res = callback(tmp_img)
-    if not isinstance(res, Detections):
-        raise ValueError("Callback function must return sv.Detection type")
-
-
-def _validate_class_id(class_id: Any, n: int) -> None:
-    is_valid = class_id is None or (
-        isinstance(class_id, np.ndarray) and class_id.shape == (n,)
-    )
-    if not is_valid:
-        raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape")
-
-
-def _validate_confidence(confidence: Any, n: int) -> None:
-    is_valid = confidence is None or (
-        isinstance(confidence, np.ndarray) and confidence.shape == (n,)
-    )
-    if not is_valid:
-        raise ValueError("confidence must be None or 1d np.ndarray with (n,) shape")
-
-
-def _validate_tracker_id(tracker_id: Any, n: int) -> None:
-    is_valid = tracker_id is None or (
-        isinstance(tracker_id, np.ndarray) and tracker_id.shape == (n,)
-    )
-    if not is_valid:
-        raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")
-
-
-def is_data_equal(data_a: Dict[str, np.ndarray], data_b: Dict[str, np.ndarray]) -> bool:
-    """
-    Compares the data payloads of two Detections instances.
-
-    Args:
-        data_a, data_b: The data payloads of the instances.
-
-    Returns:
-        True if the data payloads are equal, False otherwise.
-    """
-    return set(data_a.keys()) == set(data_b.keys()) and all(
-        np.array_equal(data_a[key], data_b[key]) for key in data_a
-    )
-
-
-def merge_data(
-    data_list: List[Dict[str, Union[np.ndarray, List]]],
-) -> Dict[str, Union[np.ndarray, List]]:
-    """
-    Merges the data payloads of a list of Detections instances.
-
-    Args:
-        data_list: The data payloads of the instances.
-
-    Returns:
-        A single data payload containing the merged data, preserving the original data
-        types (list or np.ndarray).
-
-    Raises:
-        ValueError: If data values within a single object have different lengths or if
-            dictionaries have different keys.
-    """
-    if not data_list:
-        return {}
-
-    all_keys_sets = [set(data.keys()) for data in data_list]
-    if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets):
-        raise ValueError("All data dictionaries must have the same keys to merge.")
-
-    for data in data_list:
-        lengths = [len(value) for value in data.values()]
-        if len(set(lengths)) > 1:
-            raise ValueError(
-                "All data values within a single object must have equal length.")
-
-    merged_data = {key: [] for key in all_keys_sets[0]}
-
-    for data in data_list:
-        for key in merged_data:
-            merged_data[key].append(data[key])
-
-    for key in merged_data:
-        if all(isinstance(item, list) for item in merged_data[key]):
-            merged_data[key] = list(chain.from_iterable(merged_data[key]))
-        elif all(isinstance(item, np.ndarray) for item in merged_data[key]):
-            ndim = merged_data[key][0].ndim
-            if ndim == 1:
-                merged_data[key] = np.hstack(merged_data[key])
-            elif ndim > 1:
-                merged_data[key] = np.vstack(merged_data[key])
-            else:
-                raise ValueError(f"Unexpected array dimension for key '{key}'.")
-        else:
-            raise ValueError(
-                f"Inconsistent data types for key '{key}'. Only np.ndarray and list "
-                f"types are allowed."
-            )
-
-    return merged_data
-
-
 @dataclass
 class Detections:
     """
@@ -162,11 +46,12 @@ class Detections:

     def __post_init__(self):
         n = len(self.xyxy)
-        _validate_xyxy(xyxy=self.xyxy, n=n)
-        _validate_mask(mask=self.mask, n=n)
-        _validate_class_id(class_id=self.class_id, n=n)
-        _validate_confidence(confidence=self.confidence, n=n)
-        _validate_tracker_id(tracker_id=self.tracker_id, n=n)
+        validate_xyxy(xyxy=self.xyxy, n=n)
+        validate_mask(mask=self.mask, n=n)
+        validate_class_id(class_id=self.class_id, n=n)
+        validate_confidence(confidence=self.confidence, n=n)
+        validate_tracker_id(tracker_id=self.tracker_id, n=n)
+        validate_data(data=self.data, n=n)

     def __len__(self):
         """
supervision/detection/tools/inference_slicer.py (3 changes: 1 addition & 2 deletions)
@@ -3,7 +3,7 @@

 import numpy as np

-from supervision.detection.core import Detections, validate_inference_callback
+from supervision.detection.core import Detections
 from supervision.detection.utils import move_boxes
 from supervision.utils.image import crop_image

@@ -60,7 +60,6 @@ def __init__(
         self.iou_threshold = iou_threshold
         self.callback = callback
         self.thread_workers = thread_workers
-        validate_inference_callback(callback=callback)

     def __call__(self, image: np.ndarray) -> Detections:
         """
supervision/detection/utils.py (127 changes: 126 additions & 1 deletion)
@@ -1,4 +1,5 @@
-from typing import List, Optional, Tuple
+from itertools import chain
+from typing import List, Optional, Tuple, Dict, Union, Any

 import cv2
 import numpy as np
@@ -469,3 +470,127 @@ def sum_over_mask(indices: np.ndarray, axis: tuple) -> np.ndarray:
     centroid_y = sum_over_mask(vertical_indices, aggregation_axis) / total_pixels

     return np.column_stack((centroid_x, centroid_y)).astype(int)
+
+
+def validate_xyxy(xyxy: Any, n: int) -> None:
+    is_valid = isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4)
+    if not is_valid:
+        raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape")
+
+
+def validate_mask(mask: Any, n: int) -> None:
+    is_valid = mask is None or (
+        isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask.shape[0] == n
+    )
+    if not is_valid:
+        raise ValueError("mask must be 3d np.ndarray with (n, H, W) shape")
+
+
+def validate_class_id(class_id: Any, n: int) -> None:
+    is_valid = class_id is None or (
+        isinstance(class_id, np.ndarray) and class_id.shape == (n,)
+    )
+    if not is_valid:
+        raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape")
+
+
+def validate_confidence(confidence: Any, n: int) -> None:
+    is_valid = confidence is None or (
+        isinstance(confidence, np.ndarray) and confidence.shape == (n,)
+    )
+    if not is_valid:
+        raise ValueError("confidence must be None or 1d np.ndarray with (n,) shape")
+
+
+def validate_tracker_id(tracker_id: Any, n: int) -> None:
+    is_valid = tracker_id is None or (
+        isinstance(tracker_id, np.ndarray) and tracker_id.shape == (n,)
+    )
+    if not is_valid:
+        raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")
+
+
+def validate_data(data: Dict[str, Union[np.ndarray, List]], n: int) -> None:
+    for key, value in data.items():
+        if isinstance(value, list):
+            if len(value) != n:
+                raise ValueError(f"Length of list for key '{key}' must be {n}")
+        elif isinstance(value, np.ndarray):
+            if value.ndim == 1 and value.shape[0] != n:
+                raise ValueError(f"Shape of np.ndarray for key '{key}' must be ({n},)")
+            elif value.ndim > 1 and value.shape[0] != n:
+                raise ValueError(
+                    f"First dimension of np.ndarray for key '{key}' must have size {n}")
+        else:
+            raise ValueError(f"Value for key '{key}' must be a list or np.ndarray")
+
+
+def is_data_equal(data_a: Dict[str, np.ndarray], data_b: Dict[str, np.ndarray]) -> bool:
+    """
+    Compares the data payloads of two Detections instances.
+
+    Args:
+        data_a, data_b: The data payloads of the instances.
+
+    Returns:
+        True if the data payloads are equal, False otherwise.
+    """
+    return set(data_a.keys()) == set(data_b.keys()) and all(
+        np.array_equal(data_a[key], data_b[key]) for key in data_a
+    )
+
+
+def merge_data(
+    data_list: List[Dict[str, Union[np.ndarray, List]]],
+) -> Dict[str, Union[np.ndarray, List]]:
+    """
+    Merges the data payloads of a list of Detections instances.
+
+    Args:
+        data_list: The data payloads of the instances.
+
+    Returns:
+        A single data payload containing the merged data, preserving the original data
+        types (list or np.ndarray).
+
+    Raises:
+        ValueError: If data values within a single object have different lengths or if
+            dictionaries have different keys.
+    """
+    if not data_list:
+        return {}
+
+    all_keys_sets = [set(data.keys()) for data in data_list]
+    if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets):
+        raise ValueError("All data dictionaries must have the same keys to merge.")
+
+    for data in data_list:
+        lengths = [len(value) for value in data.values()]
+        if len(set(lengths)) > 1:
+            raise ValueError(
+                "All data values within a single object must have equal length.")
+
+    merged_data = {key: [] for key in all_keys_sets[0]}
+
+    for data in data_list:
+        for key in merged_data:
+            merged_data[key].append(data[key])
+
+    for key in merged_data:
+        if all(isinstance(item, list) for item in merged_data[key]):
+            merged_data[key] = list(chain.from_iterable(merged_data[key]))
+        elif all(isinstance(item, np.ndarray) for item in merged_data[key]):
+            ndim = merged_data[key][0].ndim
+            if ndim == 1:
+                merged_data[key] = np.hstack(merged_data[key])
+            elif ndim > 1:
+                merged_data[key] = np.vstack(merged_data[key])
+            else:
+                raise ValueError(f"Unexpected array dimension for key '{key}'.")
+        else:
+            raise ValueError(
+                f"Inconsistent data types for key '{key}'. Only np.ndarray and list "
+                f"types are allowed."
+            )
+
+    return merged_data