From 0911af701b9180cb5f88a8f6299d004d40390f8c Mon Sep 17 00:00:00 2001
From: Mikko Tukiainen
Date: Tue, 3 Jun 2025 07:38:37 +0000
Subject: [PATCH 1/4] use real instead of complex tensors in Wan2.1 RoPE

---
 .../models/transformers/transformer_wan.py | 90 +++++++++++++------
 1 file changed, 65 insertions(+), 25 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index baa0ede4184e..4c6270eb820a 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -71,11 +71,24 @@ def __call__(
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                dtype = (
+                    torch.float32
+                    if hidden_states.device.type == "mps"
+                    else torch.float64
+                )
+                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2).to(dtype)
+                x1, x2 = x[..., 0], x[..., 1]
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out
 
             query = apply_rotary_emb(query, rotary_emb)
             key = apply_rotary_emb(key, rotary_emb)
@@ -179,7 +192,11 @@ def forward(
 
 class WanRotaryPosEmbed(nn.Module):
     def __init__(
-        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -189,36 +206,59 @@ def __init__(
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
-
-        freqs = []
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        freqs_cos = []
+        freqs_sin = []
+
         for dim in [t_dim, h_dim, w_dim]:
-            freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-            freqs.append(freq)
-        self.freqs = torch.cat(freqs, dim=1)
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.freqs_cos = torch.cat(freqs_cos, dim=1)
+        self.freqs_sin = torch.cat(freqs_sin, dim=1)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        freqs = self.freqs.to(hidden_states.device)
-        freqs = freqs.split_with_sizes(
-            [
-                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-                self.attention_head_dim // 6,
-                self.attention_head_dim // 6,
-            ],
-            dim=1,
+        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
+        self.freqs_sin = self.freqs_sin.to(hidden_states.device)
+
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
+
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(
+            1, 1, ppf * pph * ppw, -1
+        )
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(
+            1, 1, ppf * pph * ppw, -1
         )
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        return freqs_cos, freqs_sin
 
 
 class WanTransformerBlock(nn.Module):

From e64770b64967269f552c9159e0a764e8aa4cf19c Mon Sep 17 00:00:00 2001
From: Mikko Tukiainen
Date: Tue, 3 Jun 2025 08:24:37 +0000
Subject: [PATCH 2/4] remove the redundant type conversion

---
 src/diffusers/models/transformers/transformer_wan.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 4c6270eb820a..3179d7e4c639 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -76,19 +76,14 @@ def apply_rotary_emb(
                 freqs_cos: torch.Tensor,
                 freqs_sin: torch.Tensor,
             ):
-                dtype = (
-                    torch.float32
-                    if hidden_states.device.type == "mps"
-                    else torch.float64
-                )
-                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2).to(dtype)
+                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
                 x1, x2 = x[..., 0], x[..., 1]
                 cos = freqs_cos[..., 0::2]
                 sin = freqs_sin[..., 1::2]
                 out = torch.empty_like(hidden_states)
                 out[..., 0::2] = x1 * cos - x2 * sin
                 out[..., 1::2] = x1 * sin + x2 * cos
-                return out
+                return out.type_as(hidden_states)
 
             query = apply_rotary_emb(query, rotary_emb)
             key = apply_rotary_emb(key, rotary_emb)

From 68677687261f9602698034c09b6b493852774de5 Mon Sep 17 00:00:00 2001
From: Mikko Tukiainen
Date: Tue, 3 Jun 2025 11:39:53 +0000
Subject: [PATCH 3/4] unpack rotary_emb

---
 src/diffusers/models/transformers/transformer_wan.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 3179d7e4c639..1f3e7cbe1b69 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -85,8 +85,8 @@ def apply_rotary_emb(
                 out[..., 1::2] = x1 * sin + x2 * cos
                 return out.type_as(hidden_states)
 
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
 
             # I2V task
             hidden_states_img = None

From c19454cc60897b6d6b8cdd9a442f8010b4bbe919 Mon Sep 17 00:00:00 2001
From: Mikko Tukiainen
Date: Tue, 3 Jun 2025 12:49:28 +0000
Subject: [PATCH 4/4] register rotary embedding frequencies as non-persistent buffers

---
 src/diffusers/models/transformers/transformer_wan.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 1f3e7cbe1b69..bbe84cdce698 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -218,17 +218,14 @@ def __init__(
             freqs_cos.append(freq_cos)
             freqs_sin.append(freq_sin)
 
-        self.freqs_cos = torch.cat(freqs_cos, dim=1)
-        self.freqs_sin = torch.cat(freqs_sin, dim=1)
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
-        self.freqs_sin = self.freqs_sin.to(hidden_states.device)
-
         split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
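
The four patches above replace the complex-tensor rotation with an explicit cos/sin formulation. As a quick sanity check (not part of the patch series), the following standalone sketch compares the two paths on random data. The tensor shape and the repeat-interleaved cos/sin layout are assumptions mirroring get_1d_rotary_pos_embed with use_real=True and repeat_interleave_real=True; note that slicing cos[..., 0::2] and sin[..., 1::2] relies on that layout, where adjacent columns hold identical values.

# Standalone sanity-check sketch (not from the patches): compares the old
# complex-tensor RoPE with the new real cos/sin rotation on random data.
import torch

head_dim, seq_len = 8, 4
x = torch.randn(1, 2, seq_len, head_dim, dtype=torch.float64)  # (batch, heads, seq, head_dim)

# Random per-position complex frequencies, as with use_real=False.
angles = torch.randn(seq_len, head_dim // 2, dtype=torch.float64)
freqs = torch.polar(torch.ones_like(angles), angles)

# Old path: view adjacent pairs as complex numbers and multiply.
x_complex = torch.view_as_complex(x.unflatten(3, (-1, 2)))
out_old = torch.view_as_real(x_complex * freqs).flatten(3, 4)

# New path: repeat-interleaved cos/sin and the explicit 2x2 rotation,
# following the patched apply_rotary_emb.
freqs_cos = freqs.real.repeat_interleave(2, dim=-1)  # layout of repeat_interleave_real=True
freqs_sin = freqs.imag.repeat_interleave(2, dim=-1)
x_pairs = x.view(*x.shape[:-1], -1, 2)
x1, x2 = x_pairs[..., 0], x_pairs[..., 1]
cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
out_new = torch.empty_like(x)
out_new[..., 0::2] = x1 * cos - x2 * sin
out_new[..., 1::2] = x1 * sin + x2 * cos

assert torch.allclose(out_old, out_new)  # both formulations rotate identically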