Implemented from_yolo_nas for keypoints #1138

Merged: 10 commits, May 2, 2024

LinasKo (Collaborator) commented Apr 25, 2024

Description

Implementation of from_yolo_nas for KeyPoints.

⚠️ Missing features:

  • Does not include batch image input (I don't think we want that right now).
  • class_id array entries are hardcoded to 0, and class_name entries to "person". The results object from NAS has a class_names attribute, but it is None. It seems to be set during training and is used in other model types. Let's revisit this after we decide exactly how we want to structure class_id and class_names in KeyPoints.

Type of change

  • New feature (non-breaking change which adds functionality)

How has this change been tested, please provide a testcase or example of how you tested the change?

Prerequisites:

pip install super-gradients

Test code:

import numpy as np
import cv2

import supervision as sv
from supervision.keypoint.skeletons import Skeleton
from supervision.assets import download_assets, VideoAssets

import super_gradients


# Available models: https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS-POSE.md
yolo_nas = super_gradients.training.models.get(
    "yolo_nas_pose_s", pretrained_weights="coco_pose").to("cuda")

download_assets(VideoAssets.PEOPLE_WALKING)

# I found this important to tweak, regardless of model size
CONFIDENCE_THRESHOLD = 0.1

# cap = cv2.VideoCapture(0)
cap = cv2.VideoCapture(VideoAssets.PEOPLE_WALKING.value)
while True:
    ret, frame = cap.read()
    if not ret:
        # End of video (or a failed read): stop instead of looping forever
        break

    # TODO: I assume we're not supporting batch-image inputs yet. YOLO NAS does.

    result = yolo_nas.predict(frame, conf=CONFIDENCE_THRESHOLD)
    keypoints = sv.KeyPoints.from_yolo_nas(result)

    ann_point_large = sv.VertexAnnotator(color=sv.Color.ROBOFLOW, radius=5)
    ann_point_small = sv.VertexAnnotator(color=sv.Color.WHITE, radius=3)

    # Option 1: Use a predefined skeleton
    ann_skeleton = sv.EdgeAnnotator(
        color=sv.Color.ROBOFLOW,
        thickness=5,
        edges=Skeleton.YOLO_NAS.value
    )

    # Option 2: No skeleton
    # ann_skeleton = sv.EdgeAnnotator(
    #     color=sv.Color.ROBOFLOW,
    #     thickness=5,
    #     edges=[]
    # )

    # Option 3: Figure out automatically
    # ann_skeleton = sv.EdgeAnnotator(
    #     color=sv.Color.ROBOFLOW,
    #     thickness=5
    # )

    # Option 4: Take a guess (connect sequential points)
    # TODO: remove YOLO_NAS from Skeleton before running
    # ann_skeleton = sv.EdgeAnnotator(
    #     color=sv.Color.ROBOFLOW,
    #     thickness=5
    # )

    # Draw
    try:
        ann_skeleton.annotate(frame, keypoints)
        ann_point_large.annotate(frame, keypoints)
        ann_point_small.annotate(frame, keypoints)
    except Exception as e:
        print("Caught exception while annotating: \n", e)

    frame = cv2.resize(frame, (1024, 680))
    cv2.imshow('frame', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

Docs

  • Docs updated? What were the changes:

Linas Kondrackis added 3 commits April 25, 2024 12:29
@LinasKo LinasKo requested a review from SkalskiP April 25, 2024 10:02
supervision/keypoint/core.py (outdated, resolved)

xy = yolo_nas_results.prediction.poses[:, :, :2]
confidence = yolo_nas_results.prediction.poses[:, :, 2]
class_id = [0] * len(xy)
Collaborator

We should not include any information that is not in the result object. If class_id and class_name are not there, we should not add them. YOLO NAS allows you to train a custom pose estimation model: https://www.youtube.com/watch?v=J83ZvWfxjoA, so we should not assume that we will only get people or that all classes are the same.

Collaborator Author

After a moderate attempt, I found no way to add class_names or labels when training a pose net for YOLO NAS. I've mocked the results by using a detection network.
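
The slicing reviewed in this thread can be tried on a mock array. This is a minimal sketch, assuming the poses array has shape (n, k, 3) with (x, y, confidence) per keypoint; the mock values and the choice of 17 keypoints (the COCO pose layout) are illustrative assumptions, not output from a real model.

```python
import numpy as np

# Mock of the pose array: 2 skeletons, 17 keypoints each,
# and (x, y, confidence) per keypoint. Values are made up for illustration.
poses = np.zeros((2, 17, 3), dtype=np.float32)
poses[:, :, 0] = 100.0  # x coordinate
poses[:, :, 1] = 200.0  # y coordinate
poses[:, :, 2] = 0.9    # per-keypoint confidence

# Same slicing as in the reviewed snippet.
xy = poses[:, :, :2]         # shape (n, k, 2)
confidence = poses[:, :, 2]  # shape (n, k)
```

Note that no class_id is derived here: the pose array itself carries no class information, which is the point raised in the review comment.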

supervision/keypoint/core.py (outdated, resolved)
supervision/keypoint/skeletons.py (outdated, resolved)
Linas Kondrackis and others added 4 commits April 25, 2024 15:00
* Still hardcoded ID to -1 and name to "" if not provided - this lets us
  stick to (n,) shape.
* Found no way to add class_names to pose tracker, so used the response
  of YOLO NAS detection instead, to check what the response looks like.
LinasKo (Collaborator, Author) commented May 1, 2024

Ready for review:
https://colab.research.google.com/drive/1RGmU9lwXaffYWR9a1MegC310SWLA7Rw_?usp=sharing

I've removed hardcoding of person, but kept a special value of -1 and name "" when no label or class_name is given by the model. Finding no way to add class names when training a pose net, I've verified the structure of label / class_names by calling a detection model.

The special -1 value matches what we have in Detections and conforms to our current validators, which assert that, given (n, ...) xy points, we always have a class_id of shape (n,). If that's something we no longer wish to support, let's make a separate PR.
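
The sentinel approach described in this comment can be sketched as follows. The helper names `fill_unknown_classes` and `has_expected_shape` are hypothetical and exist only for illustration; they are not part of supervision, and the real validators may check more than shape.

```python
import numpy as np


def fill_unknown_classes(n):
    """Hypothetical helper: build placeholder class arrays of shape (n,).

    Uses -1 for unknown class ids and "" for unknown class names,
    as described in the comment above.
    """
    class_id = np.full(n, -1, dtype=int)
    class_name = np.array([""] * n, dtype=str)
    return class_id, class_name


def has_expected_shape(array, n):
    """Hypothetical check: True if `array` is a NumPy array of shape (n,)."""
    return isinstance(array, np.ndarray) and array.shape == (n,)
```

With n skeletons, both arrays keep the (n,) shape, so a shape-only check passes even though no real class information is present.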

SkalskiP (Collaborator) commented May 2, 2024

Hi @LinasKo 👋🏻, let's keep the KeyPoints behavior as close as possible to Detections. The Detections.class_id is actually a non-mandatory field and it can be None. Currently, in the develop branch, there is no scenario where you would get -1 as class_id, the same goes for tracker_id. The current validate_class_id implementation allows for class_id to be None. So, to summarize, please set class_id as None if the value is unknown.

The same applies to class names. If we don't have this value, simply don't include the key "class_names" in the data dict. This is actually the default behavior of connectors in Detections. Most of them do not return class names.
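
The guidance above (class_id is None when unknown, and no class-name key when names are missing) could look roughly like this. It is a sketch under stated assumptions, not the merged implementation: `extract_class_info` is a hypothetical helper, the `labels` and `class_names` attribute names mirror the wording of this thread rather than a verified API, and the "class_names" key is taken from the comment above and may differ in the library.

```python
from types import SimpleNamespace

import numpy as np


def extract_class_info(result):
    """Hypothetical helper: return (class_id, data) for a result object.

    class_id is None when the result carries no labels, and the data dict
    only gets a "class_names" entry when both labels and names are present.
    """
    labels = getattr(result.prediction, "labels", None)
    class_names = getattr(result, "class_names", None)

    class_id = None
    data = {}
    if labels is not None:
        class_id = np.asarray(labels, dtype=int)
        if class_names is not None:
            # Map each class id to its name.
            data["class_names"] = np.array([class_names[i] for i in class_id])
    return class_id, data


# Mock results (assumed structure): one without labels, one with labels and names.
unlabeled = SimpleNamespace(prediction=SimpleNamespace(), class_names=None)
labeled = SimpleNamespace(
    prediction=SimpleNamespace(labels=[0, 1, 0]),
    class_names=["person", "animal"],
)
```

Calling `extract_class_info(unlabeled)` yields `None` and an empty dict, while the labeled mock yields an id array plus the matching names.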

LinasKo (Collaborator, Author) commented May 2, 2024

Done. Tested and class_id is now None when labels aren't given; data is {}.

Class names and IDs still appear when input has labels and names.

SkalskiP (Collaborator) commented May 2, 2024

Awesome! I tested as well. Merging!

@SkalskiP SkalskiP merged commit 177ab6a into develop May 2, 2024
9 checks passed