diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b88..c76f8d77dff55 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -427,13 +427,21 @@ namespace {
 
 template <int rank>
 void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
-                               uint64_t *globalDim) {
+                               uint64_t *globalDim, uint64_t *globalStrides,
+                               const CUtensorMapDataType tensorDataType) {
   auto descriptor =
       reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
   *addr = descriptor->data;
   for (int i = 0; i < rank; ++i) {
     globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
   }
+  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+                                               4, 8, 2, 4, 4, 4};
+  // TODO(grypp): Check that the minormost stride is equal to the element size.
+  for (int i = 0; i < rank - 1; ++i) {
+    globalStrides[i] = static_cast<uint64_t>(
+        descriptor->strides[rank - i - 2] * elementSizeInBytes[tensorDataType]);
+  }
 }
 } // namespace
 
@@ -457,19 +465,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
   char *globalAddress = nullptr;
   switch (tensorRank) {
   case 1:
-    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 2:
-    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 3:
-    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 4:
-    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 5:
-    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   default:
     fprintf(
@@ -478,17 +491,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     return NULL;
   }
 
-  static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
-                                           4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
-    elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
   }
 
-  globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
-  for (int r = 1; r < tensorRank - 1; r++)
-    globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
   ScopedContext scopedContext;
   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                            globalAddress, globalDim, globalStrides, boxDim,