[water] Add support for vector types in wave.iterate and wave.yield operations #625
base: main
Conversation
Implements elements per thread propagation for MMA operations. Fixes iree-org#608. Signed-off-by: tyb0807 <[email protected]>
Changes:
- ReadOp: Only propagate attribute to result (register), ignore memory
- WriteOp: Only validate/propagate with register operand, ignore memory
This fixes false positives where memory resharding was incorrectly flagged as propagation errors. Fixes iree-org#622. Signed-off-by: tyb0807 <[email protected]>
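A minimal sketch of the rule above, assuming the propagation pass decides per value whether it participates based on its type. The `wave::` namespace, the dialect header path, the address-space accessor, and the enum name are assumptions for illustration, not the PR's actual code.

// #include "water/Dialect/Wave/IR/WaveTypes.h"  // assumed dialect header
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Returns true if `v` is a register-side value that should take part in
// elements-per-thread propagation; memory operands return false, so a
// resharded memory reference never shows up as a propagation mismatch.
static bool participatesInPropagation(mlir::Value v) {
  mlir::Type ty = v.getType();
  // After lowering, register values are plain vectors.
  if (mlir::isa<mlir::VectorType>(ty))
    return true;
  // Before lowering, they are wave tensors in the register address space.
  // `getAddressSpace` and `AddressSpace::Register` are assumed names.
  if (auto tensorTy = mlir::dyn_cast<wave::WaveTensorType>(ty))
    return tensorTy.getAddressSpace() == wave::AddressSpace::Register;
  return false;
}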
The operations now accept both WaveTensorInRegister (before conversion) and VectorOfAnyType (after conversion). Updated type compatibility and verification logic to handle both tensor and vector type combinations appropriately. Fixes iree-org#624. Signed-off-by: tyb0807 <[email protected]>
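On the C++ side, the updated type-compatibility check might look roughly like the sketch below. This is an illustration assuming the verifier compares carried and yielded types pairwise; it is not the PR's actual implementation, and `wave::WaveTensorType` is used here as the dialect type name referenced in this PR.

#include "mlir/IR/BuiltinTypes.h"

// Pairwise compatibility between a loop-carried value and the value
// yielded back to it. Before PropagateElementsPerThread both sides are
// wave tensors; afterwards one or both may already be plain vectors, so
// any tensor/vector pairing is treated as compatible here (an assumption
// made for this sketch).
static bool compatibleIterArgTypes(mlir::Type carried, mlir::Type yielded) {
  if (carried == yielded)
    return true;
  auto isTensorOrVector = [](mlir::Type t) {
    return mlir::isa<wave::WaveTensorType>(t) || mlir::isa<mlir::VectorType>(t);
  };
  return isTensorOrVector(carried) && isTensorOrVector(yielded);
}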
| Arg<Variadic<WaveTensorType>, "Carried values">:$iter_args, | ||
| Arg<Variadic<WaveTensorType>, "Captured values">:$captures | ||
| // Accept both WaveTensorType (before PropagateElementsPerThread) and AnyVectorOfAnyRank (after) | ||
| Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Carried values">:$iter_args, |
Why can't this be WaveTensorInRegisters? That constraint already accepts tensors with no address space, tensors in the register address space, and 1D vectors. And we most likely don't want any vector of any rank here, which would include scalable, 0-d and other nonsense.
This is because WaveTensorInRegisters doesn't work with Variadic, I think because it's a TypeConstraint (or something like that) rather than a Type.
std::optional<int64_t> value = hyper.getSymbolValue(name);
#ifndef NDEBUG
if (!value) {
  llvm::errs() << "symbol: " << name << "\n";
why remove this?
ftynse left a comment
@tgymnich, when reviewing stacked PRs (the target branch is not main), click on the last commit and only review that one to avoid making comments on things that should be addressed in other PRs.
I can adapt to the style folks use. I did accidentally click on "squash and merge" in such a stacked PR before, polluted the other branch, and had to do a bunch of force-pushing and PR re-opening to fix that.
Stacked PRs, do not merge.
The operations now accept both WaveTensorInRegister (before conversion) and VectorOfAnyType (after conversion). Updated type compatibility and verification logic to handle both tensor and vector type combinations appropriately.
Fixes #624.