
Conversation

Contributor

@tyb0807 tyb0807 commented Dec 23, 2025

Stacked PRs, do not merge.

The wave.iterate and wave.yield operations now accept both
WaveTensorInRegister (before conversion) and VectorOfAnyType (after conversion).

Updated type compatibility and verification logic to handle both tensor
and vector type combinations appropriately.
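The compatibility rule can be sketched in standalone C++ (the type names and the helper below are illustrative stand-ins, not the actual MLIR classes): two operands are considered compatible when their element types agree, whether each side is still a tensor in registers or has already been converted to a vector.

```cpp
#include <string>
#include <variant>

// Hypothetical stand-ins for the two kinds the ops may see:
// a Wave tensor still in register form (before conversion) and a
// plain vector type (after conversion).
struct WaveTensorInRegister { std::string elementType; };
struct VectorOfAnyType { std::string elementType; };

using OperandType = std::variant<WaveTensorInRegister, VectorOfAnyType>;

std::string elementTypeOf(const OperandType &t) {
  return std::visit([](const auto &v) { return v.elementType; }, t);
}

// Compatible when element types match, regardless of whether each side
// is still a tensor or already a vector.
bool typesCompatible(const OperandType &a, const OperandType &b) {
  return elementTypeOf(a) == elementTypeOf(b);
}
```

This is only a sketch of the verification idea under those assumptions; the real check lives in the op verifiers.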

Fixes #624.

Implements elements per thread propagation for MMA operations.

Fixes iree-org#608.

Signed-off-by: tyb0807 <[email protected]>
Changes:
- ReadOp: Only propagate attribute to result (register), ignore memory
- WriteOp: Only validate/propagate with register operand, ignore memory

This fixes false positives where memory resharding was incorrectly
flagged as propagation errors.
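The asymmetry above can be sketched as follows (a minimal standalone sketch with hypothetical names, not the actual pass): the elements-per-thread attribute flows only through register values, and memory operands are ignored, so a reshard on the memory side cannot be misreported as a propagation conflict.

```cpp
#include <optional>

enum class Kind { Register, Memory };

struct Value {
  Kind kind;
  std::optional<int> elementsPerThread;
};

// ReadOp: memory -> register. Propagate the attribute to the result
// (a register value) only; the memory operand carries no constraint.
Value readOp(const Value &memory, int elemsPerThread) {
  (void)memory;  // memory side is deliberately ignored
  return Value{Kind::Register, elemsPerThread};
}

// WriteOp: register -> memory. Validate only against the register
// operand; a resharded memory operand must not be flagged.
bool writeOpValid(const Value &reg, const Value &memory, int expected) {
  (void)memory;  // memory side is deliberately ignored
  return reg.kind == Kind::Register &&
         (!reg.elementsPerThread || *reg.elementsPerThread == expected);
}
```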

Fixes iree-org#622.

Signed-off-by: tyb0807 <[email protected]>
@tyb0807 tyb0807 requested a review from ftynse December 23, 2025 01:57
-    Arg<Variadic<WaveTensorType>, "Carried values">:$iter_args,
-    Arg<Variadic<WaveTensorType>, "Captured values">:$captures
+    // Accept both WaveTensorType (before PropagateElementsPerThread) and AnyVectorOfAnyRank (after)
+    Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Carried values">:$iter_args,
Contributor

Why can't this be WaveTensorInRegisters? That constraint already accepts tensors with no address space, tensors in the register address space, and 1-D vectors. And we most likely don't want any vector of any rank here, which would include scalable, 0-d, and other nonsense.

Contributor Author

This is because WaveTensorInRegisters doesn't work with Variadic. I think that's because it's a TypeConstraint (or something like that) and not a Type.

std::optional<int64_t> value = hyper.getSymbolValue(name);
#ifndef NDEBUG
if (!value) {
llvm::errs() << "symbol: " << name << "\n";
Contributor

why remove this?

Contributor

@ftynse ftynse left a comment

@tgymnich , when reviewing stacked PRs (the target branch is not main), click on the last commit and only review that one to avoid making comments on things that should be addressed in other PRs.

@tgymnich
Contributor

> @tgymnich , when reviewing stacked PRs (the target branch is not main), click on the last commit and only review that one to avoid making comments on things that should be addressed in other PRs.

@ftynse could we instead just change the base to the PR below in the stack?

@ftynse
Contributor

ftynse commented Dec 24, 2025

I can adapt to the style folks use. I did accidentally click on "squash and merge" in a stacked PR like this before, which polluted the other branch, and I had to do a bunch of force-pushing and PR re-opening to fix that.



Development

Successfully merging this pull request may close these issues.

Support vector types in wave.iterate and wave.yield operations
