
Add automatic-mask-generation pipeline for Segment Anything Model (SAM) #22840

Merged: 68 commits into huggingface:main, Apr 20, 2023

@ArthurZucker (Collaborator) commented Apr 18, 2023

What does this PR do?

This needs the SAM model to land first, plus a rebase once it is merged.

```python
import time

import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image
from transformers import pipeline

# Registered as "mask-generation" in the merged version (originally "automatic-mask-generation").
generator = pipeline("mask-generation", device=0)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

points_per_batch = 256
start = time.time()
outputs = generator(raw_image, points_per_batch=points_per_batch, pred_iou_thresh=1)
print(f"points_per_batch={points_per_batch}: {time.time() - start:.2f} s")


def show_mask(mask, ax, random_color=False):
    # Masks may come back as torch tensors; convert to a NumPy array for plotting.
    mask = mask.cpu().numpy() if hasattr(mask, "cpu") else np.asarray(mask)
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # RGBA overlay: masked pixels get `color`, all other pixels stay fully transparent.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
# Save before show(): show() may close the figure, leaving savefig with a blank canvas.
plt.savefig("car_results.png")
plt.show()
```

[3 attached images]
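The snippet above times a single call with `points_per_batch=256`. Assuming the pipeline follows the original SAM automatic mask generator, which prompts the model with a regular grid of points and feeds them through the model one batch at a time, a minimal NumPy sketch of that grid and batching could look like this. `build_point_grid` and `batch_points` are illustrative names, not the pipeline's API, and 32 points per side is only an example value.

```python
import numpy as np


def build_point_grid(n_per_side: int) -> np.ndarray:
    """(n_per_side**2, 2) array of (x, y) prompt points in [0, 1], one per grid-cell center."""
    offset = 1.0 / (2 * n_per_side)
    coords = np.linspace(offset, 1 - offset, n_per_side)
    xs = np.tile(coords[None, :], (n_per_side, 1))  # x varies along each row
    ys = np.tile(coords[:, None], (1, n_per_side))  # y varies down each column
    return np.stack([xs.reshape(-1), ys.reshape(-1)], axis=-1)


def batch_points(points: np.ndarray, points_per_batch: int):
    """Yield consecutive chunks of at most `points_per_batch` points (one forward pass each)."""
    for start in range(0, len(points), points_per_batch):
        yield points[start:start + points_per_batch]


grid = build_point_grid(32)
# Scale to pixel coordinates for a hypothetical 640x480 image.
pixel_points = grid * np.array([640, 480])
batches = list(batch_points(pixel_points, points_per_batch=256))
print(grid.shape, len(batches))  # (1024, 2) 4
```

Under that assumption, a larger `points_per_batch` means fewer forward passes (4 here, versus 16 with 64 points per batch) at the cost of more memory per pass.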

@HuggingFaceDocBuilderDev commented Apr 18, 2023

The documentation is not available anymore as the PR was closed or merged.

@ArthurZucker marked this pull request as ready for review on April 19, 2023, 11:51

@sgugger (Collaborator) left a comment

Thanks for your PR! I left some initial comments. Also, why the "automatic" in "automatic-mask-generation"? "mask-generation" is clear enough, no?

src/transformers/pipelines/__init__.py: 1 outdated review thread (resolved)
src/transformers/pipelines/automatic_mask_generation.py: 4 outdated review threads (resolved)

@Narsil (Contributor) left a comment

Overall looks good.

I think in general SAM could simply be image-segmentation, but there seem to be a lot of specifics here, with a lot of custom code, so making it standalone is OK with me for now.

The custom code is marked as private, so we can move it later. And we could always make this pipeline an alias of image-segmentation.
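To make the alias idea concrete: the `image-segmentation` pipeline returns a list of dicts with a `score`, a `label` and a PIL `mask`. A hypothetical adapter from mask-generation style outputs (binary masks plus scores) to that shape might look like this; the function name and the placeholder label are assumptions, since SAM's masks carry no class label.

```python
from typing import Dict, List, Sequence

import numpy as np
from PIL import Image


def to_image_segmentation_format(masks: Sequence, scores: Sequence[float], label: str = "object") -> List[Dict]:
    """Hypothetical adapter: binary masks + scores -> image-segmentation-style dicts.

    Each entry gets a float score, a placeholder label, and an 8-bit PIL mask
    in which foreground pixels are 255 and background pixels are 0.
    """
    results = []
    for mask, score in zip(masks, scores):
        mask_u8 = np.asarray(mask, dtype=bool).astype(np.uint8) * 255
        results.append({"score": float(score), "label": label, "mask": Image.fromarray(mask_u8)})
    return results


# Usage with two tiny 2x2 masks.
out = to_image_segmentation_format([np.array([[0, 1], [1, 1]]), np.zeros((2, 2))], [0.98, 0.5])
print(out[0]["mask"].mode, np.array(out[0]["mask"]).tolist())  # L [[0, 255], [255, 255]]
```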

src/transformers/pipelines/automatic_mask_generation.py: 1 outdated review thread (resolved)

@amyeroberts (Collaborator) left a comment

Nice - super exciting to see this pipeline!

+1 to all of Sylvain's comments about docstrings and variable names.

For the post-processing that happens here, I think most of the functionality should sit with the image processor, which should have dedicated methods for it. Other models' processors similarly filter boxes and perform RLE conversion.
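For reference, the RLE conversion mentioned here is usually the COCO-style uncompressed run-length encoding: pixels are read in column-major order and the counts alternate between runs of 0s and 1s, always starting with a (possibly empty) run of 0s. A minimal NumPy sketch under that assumption; it is not the SAM image processor's actual code.

```python
import numpy as np


def mask_to_rle(mask: np.ndarray) -> dict:
    """Uncompressed RLE of a 2D binary mask, COCO-style (column-major, zeros first)."""
    height, width = mask.shape
    flat = np.asarray(mask, dtype=bool).flatten(order="F")  # column-major order
    # Indices where the value differs from the previous pixel.
    change = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate([[0], change, [flat.size]])
    counts = np.diff(boundaries).tolist()
    if flat.size and flat[0]:
        counts = [0] + counts  # the first run must be a run of zeros, even if empty
    return {"size": [height, width], "counts": counts}


def rle_to_mask(rle: dict) -> np.ndarray:
    """Inverse of mask_to_rle: rebuild the boolean mask from alternating run lengths."""
    height, width = rle["size"]
    flat = np.zeros(height * width, dtype=bool)
    position, value = 0, False
    for count in rle["counts"]:
        flat[position:position + count] = value
        position += count
        value = not value
    return flat.reshape((height, width), order="F")


print(mask_to_rle(np.array([[0, 1], [1, 1]])))  # {'size': [2, 2], 'counts': [1, 3]}
```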

For quite a few of the methods, I found it confusing what the functions were doing and what objects were being handled. A lot of this should be resolved by Sylvain's suggestions. For some, splitting them into more atomic functions and using consistent types (e.g. not having to handle many different mask shapes) would also help.

```
@@ -0,0 +1,615 @@
import math
```

A collaborator commented:

General question (for @Narsil?): how come we don't have copyright headers for pipeline files?

A contributor replied:

I don't know.

We can add them. When I create new files, I tend to copy/paste from something else, so I might have missed some.

If copyright headers are important, shouldn't we have some sort of lint for them?
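Such a lint could be quite small. The sketch below assumes headers look like `# Copyright 2023 The HuggingFace Team...` somewhere in the first five lines; the function names and the default directory are hypothetical, not an existing transformers utility.

```python
import re
from pathlib import Path
from typing import List

# Assumed header shape, e.g. "# Copyright 2023 The HuggingFace Team. All rights reserved."
COPYRIGHT_PATTERN = re.compile(r"^#\s*Copyright\s+(\(c\)\s*)?\d{4}", re.IGNORECASE)


def has_copyright_header(text: str, max_header_lines: int = 5) -> bool:
    """True if one of the first `max_header_lines` lines looks like a copyright line."""
    head = text.splitlines()[:max_header_lines]
    return any(COPYRIGHT_PATTERN.match(line) for line in head)


def files_missing_copyright(directory: str) -> List[str]:
    """Return the .py files under `directory` that lack a copyright header."""
    return [
        str(path)
        for path in sorted(Path(directory).rglob("*.py"))
        if not has_copyright_header(path.read_text(encoding="utf-8"))
    ]


def main(argv: List[str]) -> int:
    # Hypothetical CLI entry point: a non-zero exit code makes the CI step fail.
    directory = argv[1] if len(argv) > 1 else "src/transformers/pipelines"
    offenders = files_missing_copyright(directory)
    for name in offenders:
        print(f"missing copyright header: {name}")
    return 1 if offenders else 0
```

A CI step could then run a small script that calls `sys.exit(main(sys.argv))`.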

src/transformers/pipelines/automatic_mask_generation.py: 4 outdated review threads (resolved)

```python
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box(masks)

keep_mask = ~_is_box_near_crop_edge(converted_boxes, cropped_box_image, [0, 0, original_width, original_height])
```

A collaborator commented:

What is cropped_box_image (type and what does it represent)?
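Reading the call above against the original SAM code, `cropped_box_image` appears to be the crop's `[x0, y0, x1, y1]` box in original-image pixels, and the filter drops masks that touch a crop border which is not also an image border (i.e. masks likely cut off by the crop). A minimal NumPy sketch of both helpers under that assumption; the names are illustrative, and the 20-pixel tolerance mirrors the original SAM code but is an assumption here.

```python
import numpy as np


def masks_to_boxes_xyxy(masks: np.ndarray) -> np.ndarray:
    """(N, H, W) boolean masks -> (N, 4) [x_min, y_min, x_max, y_max] boxes; empty masks -> zeros."""
    boxes = np.zeros((masks.shape[0], 4), dtype=np.int64)
    for i, mask in enumerate(masks):
        ys, xs = np.nonzero(mask)
        if xs.size:
            boxes[i] = [xs.min(), ys.min(), xs.max(), ys.max()]
    return boxes


def is_box_near_crop_edge(boxes, crop_box, image_box, atol=20.0) -> np.ndarray:
    """Flag boxes that touch a crop border which is not also a border of the full image.

    `boxes` are in crop coordinates; `crop_box` and `image_box` are [x0, y0, x1, y1]
    in original-image pixels. Boxes are shifted back by the crop offset before comparing.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    crop_box = np.asarray(crop_box, dtype=np.float64)
    image_box = np.asarray(image_box, dtype=np.float64)
    offset = np.array([crop_box[0], crop_box[1], crop_box[0], crop_box[1]])
    uncropped = boxes + offset
    near_crop = np.isclose(uncropped, crop_box[None, :], atol=atol, rtol=0)
    near_image = np.isclose(uncropped, image_box[None, :], atol=atol, rtol=0)
    return np.any(near_crop & ~near_image, axis=1)


masks = np.zeros((2, 100, 100), dtype=bool)
masks[0, 10:21, 30:41] = True  # the second mask is left empty on purpose
boxes = masks_to_boxes_xyxy(masks)  # [[30, 10, 40, 20], [0, 0, 0, 0]]
flags = is_box_near_crop_edge([[60, 60, 99, 99], [30, 10, 40, 20]], [0, 0, 100, 100], [0, 0, 200, 200])
print(flags.tolist())  # [True, False]
```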

src/transformers/pipelines/mask_generation.py: 4 outdated review threads (resolved)

@younesbelkada changed the title from "[WIP] Add mg pipeline" to "[WIP] Add automatic-mask-generation pipeline for Segment Anything Model (SAM)" on Apr 19, 2023
@younesbelkada changed the title from "[WIP] Add automatic-mask-generation pipeline for Segment Anything Model (SAM)" to "Add automatic-mask-generation pipeline for Segment Anything Model (SAM)" on Apr 20, 2023

@sgugger (Collaborator) left a comment

Would love for @amyeroberts to have a second look, but LGTM! Thanks!

@amyeroberts (Collaborator) left a comment

Nice update, the structure is looking a lot tidier!

A few main points:

  • There are still a lot of issues with the docstrings (missing, incomplete or wrong) which need to be fixed.
  • A white_pixels check should still be part of the tests.
  • Overall, the pipeline code looks good 👍 Just a few general nits there regarding argument values.
  • I'm a bit concerned about the processing code. There are a lot of assumptions about the image types and shapes which I'm not sure are always correct. For a first pass of the post-processing, we don't have to make it compatible with all cases, but it should be double-checked, and the assumptions about the inputs should be stated in the docstrings or comments.
  • Is the mask_threshold value right?
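On the `mask_threshold` question: if the masks are still raw logits when `masks > mask_threshold` is applied (SAM's mask decoder outputs logits), then, because the sigmoid is strictly increasing, a logit threshold t is equivalent to a probability threshold sigmoid(t), so 0.0 corresponds to 0.5. A quick check of that equivalence, assuming logit inputs:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


logits = np.array([-3.0, -0.1, 0.0, 0.2, 4.0])
mask_threshold = 0.0  # threshold on raw logits (assumed, see above)
binary_from_logits = logits > mask_threshold
# Same decision expressed in probability space: sigmoid(0.0) == 0.5.
binary_from_probs = sigmoid(logits) > sigmoid(mask_threshold)
print(binary_from_logits.tolist())  # [False, False, False, True, True]
```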

```python
self.assertEqual(
    nested_simplify(new_outupt, decimals=4),
    [
        {'mask': {'hash': '115ad19f5f', 'shape': (480, 640)}, 'scores': 1.0444},
```

A collaborator commented:

There should be a check here on the pixel value counts, similar to the white_pixels check in image segmentation. As mentioned before, if a single pixel changes, the hash is completely different, and without additional information, debugging the test is a lot harder. Judging from the model output, the masks are binary, so a white_pixels check should be easy to add. If all values are 0, as before, that indicates an issue with the outputs.
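A sketch of such a test-friendly summary, modeled on the idea of the `white_pixels` check rather than copied from the image-segmentation tests (`summarize_mask` is a hypothetical helper). It shows the point being made: a one-pixel change completely changes the hash, while the foreground count moves by exactly one, which points straight at the problem.

```python
import hashlib

import numpy as np


def summarize_mask(mask: np.ndarray) -> dict:
    """Test-friendly summary of a binary mask: short hash, foreground pixel count, shape."""
    mask_u8 = np.asarray(mask, dtype=bool).astype(np.uint8) * 255
    digest = hashlib.md5(mask_u8.tobytes()).hexdigest()[:10]
    white_pixels = int(np.count_nonzero(mask_u8))
    return {"hash": digest, "white_pixels": white_pixels, "shape": mask_u8.shape}


mask = np.zeros((480, 640), dtype=bool)
mask[100:200, 150:300] = True
before = summarize_mask(mask)
mask[0, 0] = True  # flip a single pixel
after = summarize_mask(mask)
print(before["hash"] != after["hash"], after["white_pixels"] - before["white_pixels"])  # True 1
```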

A contributor replied:

Agreed. As discussed offline, I will let @ArthurZucker do this in a follow-up PR before the next release!

```python
all_boxes = []
for model_output in model_outputs:
    all_scores.append(model_output.pop("iou_scores"))
    all_masks.extend(model_output.pop("masks"))
```

A collaborator commented:

The `extend` difference here hasn't been addressed. See: #22840 (comment)

A contributor replied:

The reason is that model_output.pop("masks") returns a list of masks, and post_process_for_mask_generation expects a single flat list of masks rather than nested lists. That's why we need to call extend instead of append.
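The difference is easy to see with placeholder values (strings stand in for the mask tensors, and the dict layout is a simplified assumption): `append` keeps one nested list per batch of points, while `extend` produces the flat list that the post-processing expects.

```python
# Each forward pass (one batch of point prompts) returns a list of masks.
model_outputs = [
    {"masks": ["mask_0", "mask_1"], "iou_scores": [0.91, 0.87]},
    {"masks": ["mask_2"], "iou_scores": [0.95]},
]

appended, extended = [], []
for model_output in model_outputs:
    appended.append(model_output["masks"])  # nests one list per batch
    extended.extend(model_output["masks"])  # flattens into a single list

print(appended)  # [['mask_0', 'mask_1'], ['mask_2']]
print(extended)  # ['mask_0', 'mask_1', 'mask_2']
```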

src/transformers/pipelines/mask_generation.py: 2 outdated review threads (resolved)

```python
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
    """
    Post processes mask that are automatically generated.
```

A collaborator commented:

The docstring is missing:

  • Information about the output: what do the post-processed outputs look like, and what do they represent?
  • Information about the input arguments and their types
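As an illustration of the kind of docstring being asked for, here is a hypothetical NumPy version of the non-maximum suppression step that `crops_nms_thresh` suggests (assumption: it is an IoU threshold for NMS over the mask boxes). The function names and array types are illustrative, not the image processor's.

```python
import numpy as np


def box_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """IoU between one [x0, y0, x1, y1] box and an (N, 4) array of boxes."""
    x0 = np.maximum(box[0], boxes[:, 0])
    y0 = np.maximum(box[1], boxes[:, 1])
    x1 = np.minimum(box[2], boxes[:, 2])
    y1 = np.minimum(box[3], boxes[:, 3])
    intersection = np.clip(x1 - x0, 0, None) * np.clip(y1 - y0, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return intersection / np.maximum(area + areas - intersection, 1e-9)


def nms_masks(masks, scores, boxes, nms_thresh):
    """Hypothetical post-processing step: greedy NMS over predicted masks.

    Args:
        masks (`np.ndarray` of shape `(N, H, W)`): binary masks, one per candidate.
        scores (`np.ndarray` of shape `(N,)`): predicted IoU score of each mask.
        boxes (`np.ndarray` of shape `(N, 4)`): `[x0, y0, x1, y1]` box of each mask, in pixels.
        nms_thresh (`float`): a candidate is dropped when its box IoU with an
            already-kept, higher-scoring box exceeds this value.

    Returns:
        `tuple(np.ndarray)`: `(masks, scores, boxes)` of the kept candidates,
        sorted by descending score.
    """
    scores = np.asarray(scores, dtype=np.float64)
    boxes = np.asarray(boxes, dtype=np.float64)
    order = np.argsort(-scores)
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(best)
        order = rest[box_iou(boxes[best], boxes[rest]) <= nms_thresh]
    keep = np.array(keep, dtype=np.int64)
    return np.asarray(masks)[keep], scores[keep], boxes[keep]


boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]])
scores = np.array([0.90, 0.95, 0.80])
masks = np.zeros((3, 32, 32), dtype=bool)
kept_masks, kept_scores, kept_boxes = nms_masks(masks, scores, boxes, nms_thresh=0.7)
print(kept_scores.tolist())  # [0.95, 0.8]
```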

src/transformers/models/sam/image_processing_sam.py: 4 outdated review threads (resolved)
younesbelkada and others added 6 commits April 20, 2023 17:34
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@amyeroberts (Collaborator) left a comment

LGTM - thanks for iterating!

src/transformers/models/sam/image_processing_sam.py: 1 outdated review thread (resolved)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@younesbelkada (Contributor) commented

Thank you all for your reviews!

@younesbelkada merged commit f143037 into huggingface:main on Apr 20, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker deleted the add-mg-pipeline branch on April 21, 2023, 09:12
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…SAM) (huggingface#22840)

* cleanup

* updates

* more refactoring

* make style

* update inits

* support other inputs in base

* update based on review

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>

* Update tests/pipelines/test_pipelines_automatic_mask_generation.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* update

* fixup

* TODO x and y to refactor, _h _w refactored here

* update docstring

* more nits

* style on these

* more doc fix

* rename variables

* update

* updates

* style

* update

* fix `_mask_to_rle_pytorch`

* styling

* fix mask to rle, wrong outputs

* add device arg

* update

* more updates, fix tests

* update

* update docstrings

* styling

* fixup

* add notebook on the docs

* update original sizes

* fix docstring

* update condition on points_per_batch

* updates tests

* fix CI  test

* extend is required, append does not work!

* fixup

* fix CI tests

* white pixels left

* address doc comments

* fix doc

* slow pipeline tests

* update auto init

* add revision

* make fixup

* update pipeline tag when calling tests

* alphabetical order in inits

* fix copies

* last style nits

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* reformat docstring

* more reformat

* address most of the comments

* Update src/transformers/pipelines/mask_generation.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final refactor

* Update src/transformers/models/sam/image_processing_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fixup and fix slow tests

* revert

---------

Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Labels: none yet
Projects: none yet
Development: no linked issues
6 participants