[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit adding support for from_lmm and specifically for PaliGemma… #1221

Merged
merged 9 commits into from
May 24, 2024
Prev Previous commit
Next Next commit
fix(pre_commit): 🎨 auto format pre-commit hooks
  • Loading branch information
pre-commit-ci[bot] committed May 22, 2024
commit bf43d6566b8bfa9ac6c317bceb312c414268a60b
22 changes: 9 additions & 13 deletions supervision/detection/lmm.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import re
import numpy as np
from enum import Enum
from typing import Dict, List, Tuple, Optional, Union, Any
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np


class LMM(Enum):
PALIGEMMA = 'paligemma'
PALIGEMMA = "paligemma"


REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {
LMM.PALIGEMMA: ['resolution_wh']
}
REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh"]}

ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {
LMM.PALIGEMMA: ['resolution_wh', 'classes']
}
ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh", "classes"]}


def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM:
Expand All @@ -40,13 +37,12 @@ def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM


def from_paligemma(
result: str,
resolution_wh: Tuple[int, int],
classes: Optional[List[str]] = None
result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
w, h = resolution_wh
pattern = re.compile(
r'(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s]+)')
r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s]+)"
)
matches = pattern.findall(result)
matches = np.array(matches) if matches else np.empty((0, 5))

Expand Down
71 changes: 34 additions & 37 deletions test/detection/test_lmm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from typing import Tuple, Optional, List
from typing import List, Optional, Tuple

import numpy as np
import pytest

from supervision.detection.lmm import from_paligemma
Expand All @@ -13,114 +13,111 @@
"",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str))
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # empty response
(
"\n",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str))
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # new line response
(
"the quick brown fox jumps over the lazy dog.",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str))
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with no location
(
"<loc0256><loc0768><loc0768> cat",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str))
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with missing location
(
"<loc0256><loc0256><loc0768><loc0768><loc0768> cat",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str))
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with extra location
(
"<loc0256><loc0256><loc0768><loc0768>",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str))
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with no class
(
"<loc0256><loc0256><loc0768><loc0768> catt",
(1000, 1000),
['cat', 'dog'],
(np.empty((0, 4)), np.empty(0), np.empty(0).astype(str))
["cat", "dog"],
(np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)),
), # response with invalid class
(
"<loc0256><loc0256><loc0768><loc0768> cat",
(1000, 1000),
None,
(
np.array([[250., 250., 750., 750.]]),
np.array([[250.0, 250.0, 750.0, 750.0]]),
None,
np.array(['cat']).astype(str)
)
np.array(["cat"]).astype(str),
),
), # correct response; no classes
(
"<loc0256><loc0256><loc0768><loc0768> black cat",
(1000, 1000),
None,
(
np.array([[250., 250., 750., 750.]]),
np.array([[250.0, 250.0, 750.0, 750.0]]),
None,
np.array(['black cat']).astype(np.dtype('U'))
)
np.array(["black cat"]).astype(np.dtype("U")),
),
), # correct response; no classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ;",
(1000, 1000),
['cat', 'dog'],
["cat", "dog"],
(
np.array([[250., 250., 750., 750.]]),
np.array([[250.0, 250.0, 750.0, 750.0]]),
np.array([0]),
np.array(['cat']).astype(str)
)
np.array(["cat"]).astype(str),
),
), # correct response; with classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768><loc0768> dog",
(1000, 1000),
['cat', 'dog'],
["cat", "dog"],
(
np.array([
[250., 250., 750., 750.],
[250., 250., 750., 750.]
]),
np.array([[250.0, 250.0, 750.0, 750.0], [250.0, 250.0, 750.0, 750.0]]),
np.array([0, 1]),
np.array(['cat', 'dog']).astype(np.dtype('U'))
)
np.array(["cat", "dog"]).astype(np.dtype("U")),
),
), # correct response; with classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768> cat",
(1000, 1000),
['cat', 'dog'],
["cat", "dog"],
(
np.array([[250., 250., 750., 750.]]),
np.array([[250.0, 250.0, 750.0, 750.0]]),
np.array([0]),
np.array(['cat']).astype(str)
)
np.array(["cat"]).astype(str),
),
), # partially correct response; with classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768><loc0768><loc0768> cat",
(1000, 1000),
['cat', 'dog'],
["cat", "dog"],
(
np.array([[250., 250., 750., 750.]]),
np.array([[250.0, 250.0, 750.0, 750.0]]),
np.array([0]),
np.array(['cat']).astype(str)
)
np.array(["cat"]).astype(str),
),
), # partially correct response; with classes
]
],
)
def test_from_paligemma(
result: str,
resolution_wh: Tuple[int, int],
classes: Optional[List[str]],
expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]
expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray],
) -> None:
result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes)
np.testing.assert_array_equal(result[0], expected_results[0])
Expand Down