Build mxfp4 kernel for sm120a #2285
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2285
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure
As of commit efb5860 with merge base 60c583e, the following job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The first thing that comes to mind is that the example is doing NVFP4, whereas all our recipes are doing MXFP4; e.g. https://github.com/pytorch/ao/pull/2285/files#diff-e155558499c3b1fbab1b5d3b60f032bf1e636908a8ef50a1de33bff518107019R240-R241 needs to change as well. For inference we have MXFP8 and MXFP4 support; I am planning to add an NVFP4 scaling recipe next. That being said, I would imagine that MXFP4 is supported on the 5090. cc @syed-ahmed
I noticed that as well 😭
Per the cutlass docs, I believe MXFP4 is supported on the 5090: https://github.com/NVIDIA/cutlass/blob/9d165a3b8ef446a7ff3db198413f82bcb83f46fe/media/docs/cpp/blackwell_functionality.md#blackwell-sm120-gemms. However, note the section that talks about the differences from sm100, so it's possible we need more changes to the kernel in torchao. Also, what CUDA version are you using? I'd assume you'd need a fairly recent one. I'll try to guide more next week.
@syed-ahmed I'm using CUDA 12.9. The strange thing is that the cutlass example works, but the one in torchao doesn't. I carefully compared the two, and I don't spot any difference in the template arguments.
How about the test? Are the inputs similar to the cutlass example?
```python
# Remove from main sources to prevent compilation with other architectures
sources = [
    s
    for s in sources
    if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
]
```
nit: I wouldn't mind a little helper func defined above to do all this stripping
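Such a helper might look like the sketch below. This is only an illustration of the nit, not code from the PR; the name `strip_sources` and the set-based signature are hypothetical.

```python
import os

def strip_sources(sources, excluded_basenames):
    """Drop architecture-specific kernel files from a source list.

    `excluded_basenames` is a set of file basenames (e.g. CUTLASS kernels
    that must only be compiled for one architecture) to remove.
    """
    return [s for s in sources if os.path.basename(s) not in excluded_basenames]

# Usage mirroring the filter in the diff above (hypothetical call site):
# sources = strip_sources(sources, {"mx_fp_cutlass_kernels.cu"})
```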
```python
for file in so_files:
    # only load architecture-specific target if the current GPU matches that target
```
👍
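The arch-matching gate in the loop above could be sketched as below. The `_sm100a`/`_sm120a` suffix convention follows the per-arch extensions discussed in this PR, but the mapping and helper name are hypothetical; the `torch.cuda.get_device_capability()` call is shown only in the usage comment since it needs a GPU.

```python
# Map an architecture-specific shared-library suffix to the compute
# capability it targets (sketch; suffixes follow this PR's naming).
ARCH_SUFFIXES = {
    "_sm100a": (10, 0),
    "_sm120a": (12, 0),
}

def so_matches_capability(so_name, capability):
    """Return True if `so_name` targets `capability` (a (major, minor)
    tuple), or if it carries no architecture-specific suffix at all."""
    stem = so_name.removesuffix(".so")
    for suffix, cap in ARCH_SUFFIXES.items():
        if stem.endswith(suffix):
            return cap == capability
    return True  # arch-agnostic library: safe to load on any GPU

# In the loader loop this would gate torch.ops.load_library, e.g.:
#   cap = torch.cuda.get_device_capability()
#   for file in so_files:
#       if so_matches_capability(os.path.basename(file), cap):
#           torch.ops.load_library(file)
```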
```cpp
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
              at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {

  using MmaTileShape = Shape<_128,_128,_128>;
```
Haven't looked in depth at what args are available; curious how much tuning is needed here.
Overall looks really good. Could you add a test to test/prototype/mx_formats/test_mx_mm.py, even if it wouldn't be exercised in CI? As well, if you have any perf numbers that would be great.
Update (2025/06/11)
I narrowed down the issue to templates: if the kernel is inside a templated function, even if I don't use any template arguments, I get the runtime error below (cudaFuncSetAttribute() returned error: invalid resource handle). It might be an issue with cutlass or my environment (nvcc version, compiler, ...). Hence, the solution is to create a separate source file for sm120a, without any templated functions. When we support NVFP4 in the future, we can either manually duplicate the code again, use a macro, or have a Python script to codegen the cutlass kernel creation.
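That last alternative (a Python script that codegens the non-templated kernel source) could be as simple as stamping out one source file per element type from a template string. Everything below is a hypothetical sketch: the template text, the `run_gemm_*` naming, and the dtype strings are illustrative, not the PR's actual kernel.

```python
# Hypothetical codegen sketch: render one non-templated .cu source per
# element type, sidestepping the templated-function issue described above.
KERNEL_TEMPLATE = """\
// Auto-generated: do not edit.
void run_gemm_{name}(const void* a, const void* b, void* out,
                     int M, int K, int N) {{
  // ... CUTLASS kernel body specialized for {elem_type} ...
}}
"""

def generate_kernel_source(name, elem_type):
    """Render one kernel-source specialization from the template."""
    return KERNEL_TEMPLATE.format(name=name, elem_type=elem_type)

# e.g. generate_kernel_source("mxfp4", "cutlass::float_e2m1_t") yields a
# run_gemm_mxfp4 definition with no template parameters.
```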
Other details of this PR:
- sm120a extension

Other alternatives that I have considered for the torch library loading logic:
- Due to setuptools.Extension's limitation, sm100a and sm120a kernels must stay in separate shared library files. This eliminates the option of doing a runtime check in C++.
- Register separate ops (mx_fp4_bf16_sm100a and mx_fp4_bf16_sm120a), and dispatch the correct op in Python.

Original (2025/05/31)
Just making some quick changes here to see if I can build the mxfp4 kernel on a 5090 (sm120). Eventually this will be put under torchao._C_cutlass_120a?

Setting -DCUTLASS_DEBUG_TRACE_LEVEL=1 so I can see the debug trace.

To build (using torch==2.8.0.dev20250530+cu128)

Running pytest test/prototype/mx_formats/test_mx_mm.py -v

cudaFuncSetAttribute() returned error: invalid resource handle means that the function is invalid? https://github.com/NVIDIA/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/device/gemm_universal_adapter.h#L338, which is quite strange...

For reference, I can build and run the example from Cutlass here: https://github.com/NVIDIA/cutlass/blob/v3.9.2/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu. The changes in this PR have been taken from this example. When building with CUTLASS_DEBUG_TRACE_LEVEL=1, there are also warnings in sm90_gemm_tma_warpspecialized_cooperative.hpp, so that is probably not the issue.

@drisspg
cc @alexsamardzic in case you faced this error with Cutlass before