
[MPS] Improve runtime complexity of roi_align #9100


Open · wants to merge 1 commit into base: main

Conversation


@hvaara hvaara commented Jun 8, 2025

roi_align on MPS has significantly inflated runtime complexity due to a bug in the looping behavior of the kernel. I've not found any other correctness issues with the current implementation, which closely follows the CUDA implementation. This PR fixes the runtime complexity; otherwise the kernel is semantically identical to before.

Note that this PR switches the dispatching to dispatchThreads, which has a tighter set of supported build targets than dispatchThreadgroups. See "Nonuniform threadgroup size" in the Metal feature set tables.

Some other MPS kernels in torchvision are also likely affected.

Running the example code from pytorch/pytorch#124850 (comment) before this change:

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                 model_inference         0.02%       6.412ms       100.00%       41.913s       41.913s             1                                                                                []
                     aten::where         0.00%       4.373us        80.19%       33.611s        8.403s             4                                                                          [[1000]]
             aten::nonzero_numpy         0.00%      15.335us        80.19%       33.611s        8.403s             4                                                                          [[1000]]
                   aten::nonzero        80.18%       33.605s        80.19%       33.611s        8.403s             4                                                                          [[1000]]
                     aten::where         0.00%       7.375us         2.55%        1.067s     533.698ms             2                                                                          [[4507]]
             aten::nonzero_numpy         0.00%      11.042us         2.55%        1.067s     533.695ms             2                                                                          [[4507]]
                   aten::nonzero         2.31%     969.133ms         2.55%        1.067s     533.679ms             2                                                                          [[4507]]
                      aten::topk         2.53%        1.062s         2.53%        1.062s        1.062s             1                                                     [[1, 120000], [], [], [], []]
                torchvision::nms         0.00%      52.208us         2.39%        1.004s        1.004s             1                                                               [[21, 4], [21], []]
                      aten::sort         2.39%     999.630ms         2.39%     999.635ms     999.635ms             1                                                                [[21], [], [], []]
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
Self CPU time total: 41.913s

and after:

--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
                 model_inference         0.88%       4.364ms       100.00%     493.862ms     493.862ms             1                                                                                []
                torchvision::nms        15.95%      78.782ms        17.20%      84.925ms      84.925ms             1                                                           [[4507, 4], [4507], []]
                     aten::where         0.00%       2.957us        11.38%      56.185ms      14.046ms             4                                                                          [[1000]]
             aten::nonzero_numpy         0.00%       7.379us        11.38%      56.182ms      14.045ms             4                                                                          [[1000]]
                   aten::nonzero        10.26%      50.684ms        11.37%      56.146ms      14.036ms             4                                                                          [[1000]]
                    aten::conv2d         0.00%       5.417us         6.39%      31.548ms      31.548ms             1                             [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], []]
               aten::convolution         0.00%       9.041us         6.39%      31.543ms      31.543ms             1                     [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], [], [], []]
              aten::_convolution         0.00%      12.542us         6.39%      31.534ms      31.534ms             1     [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []]
          aten::_mps_convolution         6.38%      31.520ms         6.38%      31.521ms      31.521ms             1                             [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], []]
          torchvision::roi_align         5.88%      29.036ms         5.88%      29.047ms      29.047ms             1                                [[1, 256, 200, 200], [960, 5], [], [], [], [], []]
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  --------------------------------------------------------------------------------
Self CPU time total: 493.862ms

One concern I have with the approach I'm proposing here is numeric overflow of the index with large input sizes.

Fixes pytorch/pytorch#124850

cc @malfet @kulinseth @qqaatw


pytorch-bot bot commented Jun 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9100


@hvaara force-pushed the fix-roi-align-mps branch from c4b01c0 to 34d749d on June 11, 2025

Successfully merging this pull request may close these issues.

aten::nonzero calls taking a huge amount of time when using MPS backend vs CPU