Skip to content

Use real-valued instead of complex tensors in Wan2.1 RoPE #11649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

mjkvaak-amd
Copy link
Contributor

@mjkvaak-amd mjkvaak-amd commented Jun 3, 2025

What does this PR do?

Avoids the complex tensors in Wan2.1 RoPE by using the real-valued cosine and sine instead. This boosts the performance of compiled models (inductor), where complex tensors are not supported.

Fixes # (issue)

Before submitting

To verify that the proposed RoPE and utils result in identical and stable behavior compared to the original, I ran a 100-step training of Wan2.1 (image-to-video) with both the proposed (orange) and the original (blue) implementations image - the losses are on top of each other, but you can see there are two identical curves from the hovering tooltip.

Please also find the standalone tests for checking the equivalence below:

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from typing import *
from torch import nn


class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )

        freqs_cos = []
        freqs_sin = []

        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        self.freqs_cos = torch.cat(freqs_cos, dim=1)
        self.freqs_sin = torch.cat(freqs_sin, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
        self.freqs_sin = self.freqs_sin.to(hidden_states.device)

        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )

        return freqs_cos, freqs_sin


def apply_rotary_emb(
    hidden_states: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
):
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    x = hidden_states.view(*hidden_states.shape[:-1], -1, 2).to(dtype)
    x1, x2 = x[..., 0], x[..., 1]
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


class WanRotaryPosEmbedOriginal(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=False,
                repeat_interleave_real=False,
                freqs_dtype=freqs_dtype,
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        freqs = self.freqs.to(hidden_states.device)
        freqs = freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        return freqs


def apply_rotary_emb_original(hidden_states: torch.Tensor, freqs: torch.Tensor):
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
    x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
    return x_out.type_as(hidden_states)


def test_rotary_pos_embed_value_equivalence():
    attention_head_dim = 12
    patch_size = (2, 2, 2)
    max_seq_len = 16
    batch, channels, frames, height, width = 1, attention_head_dim, 8, 8, 8
    hidden_states = torch.randn(batch, channels, frames, height, width)

    rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len)
    rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim, patch_size, max_seq_len)

    # New returns (cos, sin), original returns complex
    cos, sin = rope(hidden_states)
    orig = rope_orig(hidden_states)  # shape: (1, 1, N, D)

    # Remove batch dims for comparison
    cos = cos.squeeze(0).squeeze(0)  # (N, D)
    sin = sin.squeeze(0).squeeze(0)  # (N, D)
    orig = orig.squeeze(0).squeeze(0)  # (N, D/2), complex
    cos_real = cos[:, 0::2]
    sin_real = sin[:, 1::2]

    # Reconstruct complex tensor
    recon = cos_real + 1j * sin_real

    # Compare real and imaginary parts
    assert torch.allclose(recon.real.float(), orig.real.float(), atol=1e-5)
    assert torch.allclose(recon.imag.float(), orig.imag.float(), atol=1e-5)


def test_rotary_emb_equivalence():
    attention_head_dim = 12
    patch_size = (2, 2, 2)
    max_seq_len = 16
    batch, channels, frames, height, width = 1, attention_head_dim, 8, 8, 8
    hidden_states = torch.randn(batch, channels, frames, height, width)

    rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len)
    rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim, patch_size, max_seq_len)

    # Get rotary embeddings
    cos, sin = rope(hidden_states)
    freqs = rope_orig(hidden_states)

    # Prepare a fake attention input (B, H, N, D)
    B, H, N, D = cos.shape
    x = torch.randn(B, H, N, D, dtype=torch.float32)

    # Apply both rotary embeddings
    out_orig = apply_rotary_emb_original(x, freqs)
    out_real = apply_rotary_emb(x, cos, sin)

    # Check equivalence
    assert torch.allclose(
        out_real, out_orig, atol=1e-5
    ), "Real-valued rotary embedding does not match original complex version"

Copy link
Member

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, awesome work @mjkvaak-amd and thank you! Coincidentally, I was working on refactoring some of the rope code as well this week for compile compatibility, but you beat me to it :)

The changes looks good to me visually, but I'll quickly verify the numeric values ourselves as well.

Maybe returning a tuple from the rope layer can cause some issues with specific research repos that copy transformer implementation from diffusers but import internal layers directly, or folks using custom attention processor and expecting complex rope tensor (once this change is in main and next release). I think it should be fine as it'll be in a new release but LMK your thoughts @DN6

self.attention_head_dim // 6,
],
dim=1,
self.freqs_cos = self.freqs_cos.to(hidden_states.device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think doing it this way will cause a recompilation. We could probably just store as non-persistent buffer though with this refactor. The reason for not using buffer before was because it was a complex tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good thinking! I have added the proposed changes now.

@a-r-r-o-w
Copy link
Member

On my end, I can confirm that the numerical outputs match on many arbitrary shapes. However, I do get different final results on full inference when comparing this branch to main.

output.mp4
output2.mp4

(left is this branch, right is diffusers:main; both use the example pipeline code with same seed)

Trying to look into what could be the problem (possibly just something on my end)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants