Use real-valued instead of complex tensors in Wan2.1 RoPE #11649
base: main
Conversation
Wow, awesome work @mjkvaak-amd, and thank you! Coincidentally, I was working on refactoring some of the RoPE code this week as well for compile compatibility, but you beat me to it :)

The changes look good to me visually, but I'll quickly verify the numerical values on our side as well.

One concern: returning a tuple from the RoPE layer might cause issues for research repos that copy the transformer implementation from diffusers but import internal layers directly, or for folks using a custom attention processor that expects the complex RoPE tensor (once this change is in main and in the next release). I think it should be fine since it will land in a new release, but LMK your thoughts @DN6
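For downstream code that still expects the old complex rotary tensor, a tiny compatibility shim could rebuild it from the new `(cos, sin)` pair. This is an illustrative sketch, not part of the PR; the function name is hypothetical:

```python
import torch

def freqs_to_complex(freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """Hypothetical shim: rebuild the complex rotary tensor
    exp(i*theta) = cos(theta) + i*sin(theta) from the real-valued pair."""
    return torch.complex(freqs_cos, freqs_sin)

# Sanity check: matches torch.polar on the same angles
theta = torch.randn(77, 64)
freqs = freqs_to_complex(theta.cos(), theta.sin())
assert torch.allclose(freqs, torch.polar(torch.ones_like(theta), theta))
```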
```python
        self.attention_head_dim // 6,
    ],
    dim=1,
```

```python
        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
```
I think doing it this way will cause a recompilation. With this refactor, we could probably just store these as non-persistent buffers instead. The reason for not using a buffer before was that the tensor was complex.
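A minimal sketch of the non-persistent-buffer pattern being suggested (module and parameter names are hypothetical, not the PR's actual code): `register_buffer(..., persistent=False)` keeps the precomputed tables out of the `state_dict` while letting `.to(device)` move them with the module, so no per-forward device transfer (and no recompilation) is needed.

```python
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    """Hypothetical RoPE module storing cos/sin tables as non-persistent buffers."""

    def __init__(self, dim: int, max_seq_len: int = 1024, theta: float = 10000.0):
        super().__init__()
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_seq_len).float(), freqs)
        # persistent=False: moved by .to()/device placement, excluded from state_dict
        self.register_buffer("freqs_cos", angles.cos(), persistent=False)
        self.register_buffer("freqs_sin", angles.sin(), persistent=False)

    def forward(self, seq_len: int):
        return self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
```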
Good thinking! I have added the proposed changes now.

On my end, I can confirm that the numerical outputs match for many arbitrary shapes. However, I do get different final results on full inference when comparing this branch to main (left is this branch, right is main). Trying to look into what could be the problem (possibly just something on my end).
What does this PR do?
Avoids complex tensors in the Wan2.1 RoPE by using real-valued cosine and sine tensors instead. This boosts the performance of compiled models (inductor), where complex tensors are not supported.
Fixes # (issue)
UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
(ref: Wan-Video/Wan2.1#332, which discusses how to compile the model.) Using the real-valued cosine and sine removes this warning and provides a noticeable boost in throughput.
To verify that the proposed RoPE and utils behave identically and stably compared to the original, I ran a 100-step training of Wan2.1 (image-to-video) with both the proposed (orange) and the original (blue) implementations: the loss curves lie exactly on top of each other, but the hovering tooltip shows that there are indeed two identical curves.
Please also find the standalone tests for checking the equivalence below:
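For illustration, here is a minimal sketch of the kind of equivalence check involved (not the PR's actual test code; the function names are hypothetical and the interleaved even/odd pair layout is an assumption): it rotates the same inputs once via complex multiplication and once via real-valued cos/sin, then asserts the outputs match.

```python
import torch

def rope_complex(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Original-style rotation via complex multiplication: (x0 + i*x1) * exp(i*theta)
    xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    rot = torch.polar(torch.ones_like(freqs), freqs)
    return torch.view_as_real(xc * rot).flatten(-2)

def rope_real(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Proposed-style rotation using only real-valued tensors (inductor-friendly)
    cos, sin = freqs.cos(), freqs.sin()
    x0, x1 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out

torch.manual_seed(0)
for batch, seq, dim in [(1, 4, 32), (2, 77, 64), (3, 13, 128)]:
    x = torch.randn(batch, seq, dim)
    freqs = torch.randn(seq, dim // 2)
    assert torch.allclose(rope_complex(x, freqs), rope_real(x, freqs), atol=1e-5)
print("equivalence check passed")
```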