Skip to content

[MPS] Improve runtime complexity of roi_align #9100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 75 additions & 84 deletions torchvision/csrc/ops/mps/mps_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,105 +225,96 @@ kernel void nms<DTYPE ## 4, DTYPE>( \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tid2 [[thread_position_in_threadgroup]]);

// roi_align: produces one pooled output element per linear index. The index is
// decoded into (n, c, ph, pw), the ROI for n is read from `rois`, and the
// result is the average of bilinear_interpolate() samples taken on a
// roi_bin_grid_h x roi_bin_grid_w grid inside bin (ph, pw).
// NOTE(review): this span is a rendered diff without +/- markers -- removed and
// added lines of the kernel appear back to back, so several declarations and
// statements are duplicated and the text is not compilable as shown.
template<typename T, typename integer_t>
template <typename T, typename integer_t>
kernel void roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant float & spatial_scale [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;

// Each ROI occupies 5 consecutive values: [0] is the batch index, [1..4] are
// read below as start_w, start_h, end_w, end_h.
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];

// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}

T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4

T output_val = 0.;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
uint index [[thread_position_in_grid]])
{
// Decode linear index into (n, c, ph, pw)
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / (pooled_width * pooled_height * channels);

// Each ROI occupies 5 consecutive values: [0] is the batch index, [1..4] are
// read below as start_w, start_h, end_w, end_h.
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = static_cast<integer_t>(offset_rois[0]);

// Do not use rounding; this implementation detail is critical
T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.0);
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;

if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, static_cast<T>(1.0));
roi_height = max(roi_height, static_cast<T>(1.0));
}

T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);

// Start of the (roi_batch_ind, c) plane of `input`, which is indexed here as
// ((batch * channels + c) * height * width).
constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = sampling_ratio > 0
? sampling_ratio
: static_cast<integer_t>(ceil(roi_height / static_cast<T>(pooled_height)));
integer_t roi_bin_grid_w = sampling_ratio > 0
? sampling_ratio
: static_cast<integer_t>(ceil(roi_width / static_cast<T>(pooled_width)));

// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1));
T output_val = static_cast<T>(0.0);

// Sample points sit at the centre of each grid cell: (i + 0.5) / grid_size
// of a bin, offset by the bin's position inside the ROI.
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
(static_cast<T>(iy) + static_cast<T>(0.5)) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
(static_cast<T>(ix) + static_cast<T>(0.5)) * bin_size_w / static_cast<T>(roi_bin_grid_w);

T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
output_val /= count;

output[index] = output_val;
}

// Average the accumulated samples and store the result for this output element.
output_val /= count;
output[index] = output_val;
}

// REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE): explicitly instantiates
// roi_align<DTYPE, INT_DTYPE> and gives it the host-visible name
// "roi_align_<DTYPE>" (the DTYPE token is stringized and appended).
// This variant declares output_size at buffer(3), spatial_scale at buffer(11),
// and the three uint2 threadgroup-position arguments.
#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
// REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE): explicitly instantiates
// roi_align<DTYPE, INT_DTYPE> and gives it the host-visible name
// "roi_align_<DTYPE>" (the DTYPE token is stringized and appended).
// This variant declares spatial_scale at buffer(3), has no output_size
// argument, and takes a single `uint index [[thread_position_in_grid]]`.
#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant float & spatial_scale [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
uint index [[thread_position_in_grid]]);

template<typename T, typename integer_t>
kernel void roi_align_backward(
Expand Down Expand Up @@ -1005,7 +996,7 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
Expand Down
20 changes: 6 additions & 14 deletions torchvision/csrc/ops/mps/roi_align_kernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);

const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

auto threadsPerGrid = MTLSizeMake(output_size, 1, 1);
auto threadsPerThreadgroup =
MTLSizeMake(std::min(static_cast<int64_t>(visionPSO.maxTotalThreadsPerThreadgroup), output_size), 1, 1);

// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

Expand All @@ -68,24 +68,16 @@
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];

[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];

// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
[computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];

getMPSProfiler().endProfileKernel(visionPSO);
}
Expand Down