[MPS] Improve runtime complexity of roi_align
#9100
+81
−98
`roi_align` on MPS has significantly inflated runtime complexity due to a bug in the looping behavior of the kernel. I have not found any other correctness issues with the current implementation, which closely follows the CUDA implementation. This PR fixes the runtime complexity; otherwise the kernel is semantically identical to before.
Note that this PR switches the dispatching to `dispatchThreads`, which has a tighter build target set than `dispatchThreadgroups`. See "Nonuniform threadgroup size" in the Metal feature set tables.
Some other MPS kernels in vision are also likely affected.
Running the example code from pytorch/pytorch#124850 (comment) before:
and after
One concern I have with the approach I'm proposing here is numeric overflow of the index with large input sizes.
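To make the overflow concern concrete, here is a minimal Python sketch that checks whether the largest flat index of an output tensor fits in a fixed-width integer. The index width actually used by the kernel is not stated here, so the 32-bit default is an assumption, and the helper names and example shapes are hypothetical.

```python
from math import prod


def max_flat_index(shape):
    """Largest flat (linearized) index for a tensor of the given shape."""
    return prod(shape) - 1


def fits_index_type(shape, bits=32, signed=False):
    """Return True if every flat index fits in an integer of `bits` width.

    The 32-bit default is an assumption, not the kernel's confirmed type.
    """
    limit = 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1
    return max_flat_index(shape) <= limit


# Output shape is (num_rois, channels, pooled_height, pooled_width).
small = (1000, 256, 7, 7)       # 12,544,000 elements
large = (100000, 2048, 7, 7)    # 10,035,200,000 elements

print(fits_index_type(small), fits_index_type(large))
```

A check like this could guard the dispatch on the host side, either rejecting oversized inputs or selecting a wider index type.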
Fixes pytorch/pytorch#124850
cc @malfet @kulinseth @qqaatw