Describe the bug
When running the WanPipeline from diffusers on an MPS device, the pipeline fails with the error:
TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.
Investigation indicates that in the transformer component (specifically in WanRotaryPosEmbed, the rotary positional embedding module), the frequency tensors are computed in torch.float64 and then converted to complex via torch.view_as_complex. This yields a ComplexDouble tensor (torch.complex128), a dtype the MPS backend does not support.
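The failing conversion can be reproduced in isolation; a minimal sketch (the exact call site inside diffusers may differ):

import torch

# float64 pairs viewed as complex give complex128 ("ComplexDouble") on the CPU...
freqs = torch.view_as_complex(torch.randn(4, 2, dtype=torch.float64))
# ...and moving that tensor to an MPS device raises the reported TypeError.
freqs = freqs.to("mps")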
Steps to Reproduce:
1. On an Apple Silicon Mac with MPS enabled, use diffusers (version 0.33.0.dev0) with a recent PyTorch nightly (2.7.0.dev20250305).
2. Load the model pipeline:
from diffusers import AutoencoderKLWan, WanPipeline
vae = AutoencoderKLWan.from_pretrained("<model_path>", subfolder="vae", torch_dtype=torch.float32).to("mps")
pipe = WanPipeline.from_pretrained("<model_path>", vae=vae, torch_dtype=torch.float32).to("mps")
3. Encode the prompts and call the pipeline for inference.
The error occurs during the transformer's forward pass, in the rotary embedding function, at the call torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(...).
Expected Behavior:
When running on MPS, all computations (including the construction of rotary positional encodings) should use single-precision floats (and their corresponding complex type, i.e. torch.cfloat). In other words, the pipeline should ensure that no operations create or convert to ComplexDouble (complex128) tensors on MPS.
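By contrast, the single-precision path is accepted; a sketch assuming the partial complex64 support present in recent PyTorch nightlies:

import torch

# float32 pairs become complex64 (torch.cfloat)...
freqs32 = torch.view_as_complex(torch.randn(4, 2, dtype=torch.float32))
# ...which transfers to MPS without error on the nightly listed below.
freqs32 = freqs32.to("mps")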
Workaround:
A temporary fix is to patch the helpers used to compute the rotary embeddings so that they are forced to torch.float32. For instance, override get_1d_rotary_pos_embed, which computes the frequency tensor, so that it always uses freqs_dtype=torch.float32 regardless of the default; additionally, patch WanRotaryPosEmbed.forward to downcast its output (complex128 to complex64) when needed. The full patch script is given under Reproduction below.
Reproduction
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
# ------------------------------------------------------------------------------
# 1. Patch torch.view_as_complex to avoid creating ComplexDouble on MPS.
_orig_view_as_complex = torch.view_as_complex
def patched_view_as_complex(tensor):
    if tensor.device.type == "mps" and tensor.dtype == torch.float64:
        tensor = tensor.to(torch.float32)
    return _orig_view_as_complex(tensor)
torch.view_as_complex = patched_view_as_complex
# ------------------------------------------------------------------------------
# 2. Patch get_1d_rotary_pos_embed so that it always computes frequencies as float32.
try:
    from diffusers.models.transformers.embeddings import get_1d_rotary_pos_embed as original_get_1d_rotary_pos_embed
    import diffusers.models.transformers.embeddings as embeddings_mod
except ImportError:
    from diffusers.models.embeddings import get_1d_rotary_pos_embed as original_get_1d_rotary_pos_embed
    import diffusers.models.embeddings as embeddings_mod
def patched_get_1d_rotary_pos_embed(*args, **kwargs):
    # Force single-precision frequencies no matter what the caller requests.
    # (*args/**kwargs keeps the patch robust to signature changes; this assumes
    # callers pass freqs_dtype as a keyword argument, as WanRotaryPosEmbed does.)
    kwargs["freqs_dtype"] = torch.float32
    return original_get_1d_rotary_pos_embed(*args, **kwargs)
embeddings_mod.get_1d_rotary_pos_embed = patched_get_1d_rotary_pos_embed
# Rebinding the embeddings module is not enough on its own: transformer_wan
# imports the function by name, so rebind it there as well (guarded, in case
# the module layout differs).
import diffusers.models.transformers.transformer_wan as transformer_wan_mod
if hasattr(transformer_wan_mod, "get_1d_rotary_pos_embed"):
    transformer_wan_mod.get_1d_rotary_pos_embed = patched_get_1d_rotary_pos_embed
# ------------------------------------------------------------------------------
# 3. Patch WanRotaryPosEmbed.forward to ensure its output is single precision.
from diffusers.models.transformers.transformer_wan import WanRotaryPosEmbed
_orig_rope_forward = WanRotaryPosEmbed.forward
def patched_rope_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    result = _orig_rope_forward(self, hidden_states)
    # The rotary embeddings are complex, so double precision shows up as
    # complex128; downcast it (and plain float64, just in case) on MPS.
    if hidden_states.device.type == "mps":
        if result.dtype == torch.complex128:
            result = result.to(torch.complex64)
        elif result.dtype == torch.float64:
            result = result.to(torch.float32)
    return result
WanRotaryPosEmbed.forward = patched_rope_forward
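# ------------------------------------------------------------------------------
# Optional sanity check: confirm all three monkeypatches are active before
# starting the long generation run. Safe to delete.
assert torch.view_as_complex is patched_view_as_complex
assert embeddings_mod.get_1d_rotary_pos_embed is patched_get_1d_rotary_pos_embed
assert WanRotaryPosEmbed.forward is patched_rope_forward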
# ------------------------------------------------------------------------------
# Model setup.
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video
model_id = os.path.expanduser("~/Wan2.1-T2V-1.3B-Diffusers")  # from_pretrained does not expand "~"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.to("mps")
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float32)
pipe.to("mps")
pipe.enable_attention_slicing()
# ------------------------------------------------------------------------------
# Define prompts.
prompt = "A cat walks on the grass, realistic"
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, "
    "worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, "
    "deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
)
# ------------------------------------------------------------------------------
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    max_sequence_length=226,
    device="mps",
    dtype=torch.float32,
)
# ------------------------------------------------------------------------------
# Generate video frames.
output = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
# ------------------------------------------------------------------------------
# Export to video.
export_to_video(output, "output.mp4", fps=15)
Logs
System Info
diffusers: 0.33.0.dev0
PyTorch: 2.7.0.dev20250305 (nightly)
OS: macOS on Apple Silicon with MPS enabled
Other libraries: torchaudio 2.6.0.dev20250305, torchvision 0.22.0.dev20250305
Device: MPS
Who can help?
No response