Diffusers Transformer Pipeline Produces ComplexDouble Tensors on MPS, Causing Conversion Error #10986

Closed
@mozzipa

Describe the bug

Running WanPipeline from diffusers on an MPS device fails with:

TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.

The cause appears to be in the transformer's rotary positional embedding (WanRotaryPosEmbed): the frequency tensors are computed in torch.float64 and then converted to complex via torch.view_as_complex. The result is a ComplexDouble tensor (torch.complex128), which the MPS backend does not support.
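
The dtype promotion described above can be seen on CPU without any MPS hardware. This is a minimal sketch, not the diffusers code path; complex_dtype_from is a hypothetical helper introduced only for illustration.

```python
import torch

def complex_dtype_from(real_dtype):
    """Return the complex dtype produced by viewing real pairs of `real_dtype` as complex."""
    # view_as_complex expects a trailing dimension of size 2 (real, imaginary).
    pairs = torch.zeros(4, 2, dtype=real_dtype)
    return torch.view_as_complex(pairs).dtype

double_result = complex_dtype_from(torch.float64)  # double-precision input
single_result = complex_dtype_from(torch.float32)  # single-precision input
print(double_result, single_result)
```

A float64 input yields complex128 (ComplexDouble), while a float32 input yields complex64, which is why the precision of the frequency computation matters here.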

Steps to Reproduce:

On an Apple Silicon Mac with MPS enabled, use diffusers (version 0.33.0.dev0) along with a recent PyTorch nightly (2.7.0.dev20250305).
Load a model pipeline as follows:

import torch
from diffusers import AutoencoderKLWan, WanPipeline
vae = AutoencoderKLWan.from_pretrained("<model_path>", subfolder="vae", torch_dtype=torch.float32).to("mps")
pipe = WanPipeline.from_pretrained("<model_path>", vae=vae, torch_dtype=torch.float32).to("mps")

Encode prompts and call the pipeline for inference.
The error occurs during the transformer’s forward pass in the rotary embedding function when it calls torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(...).

Expected Behavior:
When running on MPS, all computations, including the construction of the rotary positional encodings, should use single-precision floats (torch.float32) and the matching complex type (torch.complex64, also available as torch.cfloat). In other words, the pipeline should never create or convert to ComplexDouble (torch.complex128) tensors on MPS.
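
One way to picture the expected behavior is a device-aware dtype choice for the frequency computation. The sketch below is an assumption about what such a fix could look like, not the diffusers implementation; rope_freqs_dtype and rotary_freqs are hypothetical names, and the tensors are built on CPU so the example runs anywhere.

```python
import torch

def rope_freqs_dtype(device):
    """Pick float32 for MPS (no double-precision support), float64 otherwise."""
    return torch.float32 if torch.device(device).type == "mps" else torch.float64

def rotary_freqs(dim, seq_len, device="cpu", theta=10000.0):
    """Build complex rotary frequencies of shape (seq_len, dim // 2) on CPU.

    `device` only selects the precision here; the caller would move the
    result to the target device afterwards.
    """
    dtype = rope_freqs_dtype(device)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=dtype), inv_freq)
    # polar(abs, angle) returns complex64 for float32 inputs, complex128 for float64.
    return torch.polar(torch.ones_like(angles), angles)

mps_freqs = rotary_freqs(8, 4, device="mps")
cpu_freqs = rotary_freqs(8, 4, device="cpu")
print(mps_freqs.dtype, cpu_freqs.dtype)
```

With this choice the MPS path stays in complex64 while other devices keep the higher-precision default.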

Workaround:
A temporary fix is to patch the helper functions that compute the rotary embeddings so that they always use torch.float32. For instance, the function that computes the frequency tensor (get_1d_rotary_pos_embed) can be overridden so that it is always called with freqs_dtype=torch.float32, regardless of the defaults. Additionally, patching WanRotaryPosEmbed.forward to downcast a double-precision output to single precision avoids this error.
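
One caveat with this kind of monkey-patching: rebinding a function on the module that defines it does not affect modules that already imported it by name. The sketch below shows the effect with stand-in modules (defs, user, and helper are hypothetical names); whether the Wan transformer module imports the helper by name is an assumption worth checking before relying on the patch.

```python
import types

# Stand-in for the module that defines the helper (hypothetical names throughout).
defs = types.ModuleType("defs")
exec("def helper():\n    return 'float64'", defs.__dict__)

# Stand-in for a module that ran `from defs import helper` and calls it.
user = types.ModuleType("user")
user.helper = defs.helper
exec("def call():\n    return helper()", user.__dict__)

def patched_helper():
    return "float32"

# Rebinding only the defining module leaves the importer's reference untouched.
defs.helper = patched_helper
after_defs_patch = user.call()

# Rebinding the name in the importing module is what changes its behavior.
user.helper = patched_helper
after_user_patch = user.call()

print(after_defs_patch, after_user_patch)
```

If the patch seems to have no effect, rebinding the name in the module that actually calls the helper is the usual remedy.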

Reproduction

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# ------------------------------------------------------------------------------
# 1. Patch torch.view_as_complex to avoid creating ComplexDouble on MPS.
_orig_view_as_complex = torch.view_as_complex

def patched_view_as_complex(tensor):
    if tensor.device.type == "mps" and tensor.dtype == torch.float64:
        tensor = tensor.to(torch.float32)
    return _orig_view_as_complex(tensor)

torch.view_as_complex = patched_view_as_complex

# ------------------------------------------------------------------------------
# 2. Patch get_1d_rotary_pos_embed so that it always computes frequencies as float32.
try:
    from diffusers.models.transformers.embeddings import get_1d_rotary_pos_embed as original_get_1d_rotary_pos_embed
    import diffusers.models.transformers.embeddings as embeddings_mod
except ImportError:
    from diffusers.models.embeddings import get_1d_rotary_pos_embed as original_get_1d_rotary_pos_embed
    import diffusers.models.embeddings as embeddings_mod

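def patched_get_1d_rotary_pos_embed(*args, **kwargs):
    # Forward every argument unchanged (positional order and defaults stay
    # those of the original function), but force float32 frequencies.
    kwargs["freqs_dtype"] = torch.float32
    return original_get_1d_rotary_pos_embed(*args, **kwargs)

# Caveat: a module that imported this function by name keeps its own
# reference and has to be patched separately.
embeddings_mod.get_1d_rotary_pos_embed = patched_get_1d_rotary_pos_embed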

# ------------------------------------------------------------------------------
# 3. Patch WanRotaryPosEmbed.forward to ensure its output is float32.
from diffusers.models.transformers.transformer_wan import WanRotaryPosEmbed
_orig_rope_forward = WanRotaryPosEmbed.forward

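def patched_rope_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    result = _orig_rope_forward(self, hidden_states)
    if hidden_states.device.type == "mps":
        # Downcast double-precision outputs; the rotary frequencies are complex,
        # so complex128 is handled as well as float64.
        if result.dtype == torch.float64:
            result = result.to(torch.float32)
        elif result.dtype == torch.complex128:
            result = result.to(torch.complex64)
    return result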

WanRotaryPosEmbed.forward = patched_rope_forward

# ------------------------------------------------------------------------------
# Model setup.
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

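# Expand "~" explicitly; from_pretrained does not resolve it to the home directory.
model_id = os.path.expanduser("~/Wan2.1-T2V-1.3B-Diffusers")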

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.to("mps")

pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float32)
pipe.to("mps")

pipe.enable_attention_slicing()

# ------------------------------------------------------------------------------
# Define prompts.
prompt = "A cat walks on the grass, realistic"
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, "
    "worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, "
    "deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
)

# ------------------------------------------------------------------------------
# Encode the prompts up front so the embeddings can be passed to the pipeline.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    max_sequence_length=226,
    device="mps",
    dtype=torch.float32
)

# ------------------------------------------------------------------------------
# Generate video frames.
output = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0
).frames[0]

# ------------------------------------------------------------------------------
# Export to video.
export_to_video(output, "output.mp4", fps=15)
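
As a sanity check for patches like the ones in the script above, it may help to scan a model for tensors whose dtype MPS cannot hold. The helper below is a sketch under assumptions: find_unsupported_tensors is a hypothetical name, and it only inspects parameters, registered buffers, and plain tensor attributes stored directly on each submodule.

```python
import torch

# Double-precision dtypes that the MPS backend rejects.
MPS_UNSUPPORTED = {torch.float64, torch.complex128}

def find_unsupported_tensors(module):
    """Return (qualified_name, dtype) pairs for tensors MPS cannot represent."""
    found = []
    for mod_name, sub in module.named_modules():
        named = list(sub.named_parameters(recurse=False))
        named += list(sub.named_buffers(recurse=False))
        # Plain tensor attributes (not parameters or buffers) live in __dict__.
        named += [(k, v) for k, v in vars(sub).items() if isinstance(v, torch.Tensor)]
        for name, tensor in named:
            if tensor.dtype in MPS_UNSUPPORTED:
                found.append((f"{mod_name}.{name}".lstrip("."), tensor.dtype))
    return found

class Toy(torch.nn.Module):
    """Small stand-in model with one float64 buffer and one complex128 attribute."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.register_buffer("buf", torch.zeros(2, dtype=torch.float64))
        self.freqs = torch.zeros(2, dtype=torch.complex128)

report = dict(find_unsupported_tensors(Toy()))
print(report)
```

In the script above this could be called as find_unsupported_tensors(pipe.transformer) before pipe.to("mps"), assuming the pipeline exposes its transformer under that attribute.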

Logs

System Info

diffusers: 0.33.0.dev0
PyTorch: 2.7.0.dev20250305 (nightly)
OS: macOS on Apple Silicon with MPS enabled
Other libraries: torchaudio 2.6.0.dev20250305, torchvision 0.22.0.dev20250305
Device: MPS
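
For anyone trying to reproduce this, the same details can be gathered programmatically. This is a small sketch; collect_env is a hypothetical helper, and packages that are not installed are reported as such rather than raising.

```python
import platform
from importlib import metadata

import torch

def collect_env(packages=("diffusers", "torch", "torchvision", "torchaudio")):
    """Collect platform details, package versions, and MPS status into a dict."""
    info = {"platform": platform.platform(), "machine": platform.machine()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    # is_built: PyTorch was compiled with MPS support; is_available: usable right now.
    info["mps_built"] = torch.backends.mps.is_built()
    info["mps_available"] = torch.backends.mps.is_available()
    return info

env = collect_env()
for key, value in env.items():
    print(f"{key}: {value}")
```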

Who can help?

No response

Metadata

Labels

bug, stale
