Skip to content

🐛 [Bug] Accuracy issue with dynamic shapes for rotary embeddings #3978

@zhaoyuanh

Description

@zhaoyuanh

Bug Description

from typing import Optional, Tuple
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch_tensorrt


class CosmosRotaryPosEmbed(nn.Module):
    """Rotary position embedding over a (frames, height, width) grid.

    ``forward`` derives its cos/sin tables purely from the *shape* and device of
    ``hidden_states`` (the tensor's values are never read); every call recomputes
    them in float32. The module registers no parameters or buffers.

    Args:
        hidden_size: Total rotary dimension, split into ``dim_t``/``dim_h``/``dim_w``.
        max_size: Maximum (frames, height, width), in the same units as the trailing
            three dims of ``hidden_states``; divided by ``patch_size`` and used to size
            the position table (its max entry).
        patch_size: Patch size per (frames, height, width) axis.
        base_fps: Frame rate that temporal positions are rescaled to when ``fps`` is given.
        rope_scale: Per-axis (time, height, width) scale that multiplies each axis'
            10000 base theta via ``rope_scale ** (dim / (dim - 2))``.
    """

    def __init__(
        self,
        hidden_size: int,
        max_size: Tuple[int, int, int] = (128, 240, 240),
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        base_fps: int = 24,
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
    ) -> None:
        super().__init__()

        self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
        self.patch_size = patch_size
        self.base_fps = base_fps

        # Height and width each get an even share of ~1/3; time takes the remainder.
        self.dim_h = hidden_size // 6 * 2
        self.dim_w = hidden_size // 6 * 2
        self.dim_t = hidden_size - self.dim_h - self.dim_w

        # Theta multipliers per axis. NOTE(review): dim / (dim - 2) raises
        # ZeroDivisionError if any of dim_h/dim_w/dim_t equals 2 (e.g. hidden_size 6..11).
        self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
        self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
        self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))

    @staticmethod
    def _inverse_freqs(dim: int, theta: float, device: torch.device) -> torch.Tensor:
        """Return ``dim // 2`` float32 inverse frequencies ``1 / theta ** (2i / dim)``."""
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32)[: (dim // 2)] / dim
        return 1.0 / (theta**exponents)

    def forward(
        self, hidden_states: torch.Tensor, fps: Optional[int] = None, num_ranks: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute rotary cos/sin tables for the patched grid of ``hidden_states``.

        Args:
            hidden_states: 5-D tensor unpacked as (batch, channels, frames, height,
                width); only its shape and device are used.
            fps: Input frame rate. ``None`` leaves temporal positions unscaled (images);
                otherwise positions are multiplied by ``base_fps / fps`` (videos).
            num_ranks: If given, the temporal grid is built ``num_ranks`` times longer
                and only the first of ``num_ranks`` equal temporal chunks is returned.

        Returns:
            ``(cos, sin)``, each float32 with shape ``(T * H * W, D)`` where T, H, W are
            frames/height/width floor-divided by ``patch_size`` and
            ``D = 2 * (dim_t // 2 + dim_h // 2 + dim_w // 2)``.
        """
        _, _, num_frames, height, width = hidden_states.shape
        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
        if num_ranks is not None:
            pe_size[0] = pe_size[0] * num_ranks
        device = hidden_states.device

        seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
        h_spatial_freqs = self._inverse_freqs(self.dim_h, 10000.0 * self.h_ntk_factor, device)
        w_spatial_freqs = self._inverse_freqs(self.dim_w, 10000.0 * self.w_ntk_factor, device)
        temporal_freqs = self._inverse_freqs(self.dim_t, 10000.0 * self.t_ntk_factor, device)

        # Use expand() (a view) instead of repeat() for torch_tensorrt compatibility.
        emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].expand(pe_size[0], -1, pe_size[2], -1)
        emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].expand(pe_size[0], pe_size[1], -1, -1)

        t_positions = seq[: pe_size[0]]
        if fps is not None:
            # Videos: rescale frame indices to base_fps time units.
            t_positions = t_positions / fps * self.base_fps
        emb_t = torch.outer(t_positions, temporal_freqs)[:, None, None, :].expand(-1, pe_size[1], pe_size[2], -1)

        # All operands are float32 already, so no trailing dtype cast is needed.
        freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2)
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        if num_ranks is not None:
            # Rows are time-major, so chunk 0 holds the first pe_size[0] // num_ranks frames.
            cos = cos.view(num_ranks, cos.shape[0] // num_ranks, *cos.shape[1:])[0]
            sin = sin.view(num_ranks, sin.shape[0] // num_ranks, *sin.shape[1:])[0]
        return cos, sin


def export_attention(model, hidden_states, fps, num_ranks):
    with torch.no_grad():
        # Only mark sequence length as dynamic, like run_llm.py does
        # Don't mark batch dimension as dynamic to avoid constraint violations
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use tuple format like export_llm - only mark sequence length (dim 1) as dynamic
        ep = torch.export.export(
            model,
            args=(hidden_states, fps, num_ranks),
            kwargs={},
            dynamic_shapes=({2: seq_len}, None, None), 
            strict=False,
        )

    return ep


def compile_torchtrt(model, hidden_states, fps, num_ranks, min_block_size, debug, *, use_explicit_typing: bool = False):
    """Export ``model`` with a dynamic frame dimension and compile it with Torch-TensorRT.

    Args:
        model: Module to compile; invoked as ``model(hidden_states, fps, num_ranks)``.
        hidden_states: Example input used for export and as the compile input spec.
        fps: Example ``fps`` argument.
        num_ranks: Example ``num_ranks`` argument.
        min_block_size: Minimum number of ops per TensorRT-converted block.
        debug: Enable Torch-TensorRT debug logging.
        use_explicit_typing: If True, build a strongly typed TensorRT network so ops the
            model writes in float32 stay float32. ``enabled_precisions`` is then
            ``{torch.float32}``, which Torch-TensorRT requires in that mode. If False
            (default, the original repro setting), the network is weakly typed with
            bf16 enabled.

    Returns:
        The compiled module returned by ``torch_tensorrt.dynamo.compile``.
    """
    ep = export_attention(model, hidden_states, fps, num_ranks)

    # NOTE(review): with weak typing and bf16 enabled, TensorRT may choose bf16 kernels
    # for the model's float32 cos/sin path. Near 32-64 rad, bf16 spacing is 0.25, so the
    # rounding error is up to ~0.125. That matches the ~0.117 mismatch reported only with
    # dynamic shapes; with static shapes the shape-only computation can constant-fold.
    # This is unconfirmed — compare against use_explicit_typing=True.
    enabled_precisions = {torch.float32} if use_explicit_typing else {torch.bfloat16}
    use_fp32_acc = False

    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states, fps, num_ranks],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )

    return trt_model


if __name__ == "__main__":
    min_block_size = 1
    attention_head_dim = 128
    debug = False
    device = "cuda"
    tolerance = 0.01

    # hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim,
            max_size=(128, 240, 240),
            patch_size=(1, 2, 2),
            rope_scale=(2.0, 1.0, 1.0),
        ).to(device)

        # CosmosRotaryPosEmbed registers no parameters or buffers, so this cast does not
        # change its (explicitly float32) computation; kept to mirror the bf16 pipeline.
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # The model only reads the shape (dim 2 = frames) and device of this tensor.
        hidden_states = torch.randn(1, 17, 8, 88, 160, dtype=input_dtype, device=device)
        fps = 30
        num_ranks = 2

        # Eager PyTorch reference.
        pyt_output_cos, pyt_output_sin = model(hidden_states, fps, num_ranks)
        print("PyTorch output shape:", pyt_output_cos.shape, pyt_output_sin.shape)
        print("Pytorch output:", pyt_output_cos.flatten(), pyt_output_sin.flatten())

        # Compile the model with Torch-TensorRT (torch.export with a dynamic frame dim).
        # An alternative path is torch.compile(model, backend="torch_tensorrt",
        # options={"enabled_precisions": {input_dtype}, "use_python_runtime": True,
        # "min_block_size": min_block_size}, dynamic=None).
        trt_model = compile_torchtrt(model, hidden_states, fps, num_ranks, min_block_size, debug)
        trt_model = trt_model.to(device)

        trt_output_cos, trt_output_sin = trt_model(hidden_states, fps, num_ranks)
        print("TensorRT output shape:", trt_output_cos.shape, trt_output_sin.shape)
        print("TensorRT output:", trt_output_cos.flatten(), trt_output_sin.flatten())

    # Verify results match
    diff_cos = (pyt_output_cos - trt_output_cos).abs().max().item()
    diff_sin = (pyt_output_sin - trt_output_sin).abs().max().item()
    print(f"Max difference between PyTorch and TRT: {diff_cos}, {diff_sin}")

    # Check whether each output is within tolerance (cos first, then sin).
    for diff in (diff_cos, diff_sin):
        if diff < tolerance:
            print(f"✅ Results match! (difference: {diff} < {tolerance})")
        else:
            print(f"⚠️  Results differ! (difference: {diff} >= {tolerance})")

Here is the result:

⚠️  Results differ! (difference: 0.11663024872541428 >= 0.01)
⚠️  Results differ! (difference: 0.1167231947183609 >= 0.01)

When I disable dynamic shapes (e.g. by commenting out dynamic_shapes=({2: seq_len}, None, None)), the results match.

To Reproduce

Steps to reproduce the behavior:

  1. Run the Python script above.

Expected behavior

The Torch-TRT output matches the Torch output.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
  • PyTorch Version (e.g. 1.0): 2.9.0
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
  • Are you using local sources or building from archives: No
  • Python version: 3.10
  • CUDA version: 12.9
  • GPU models and configuration: Nvidia B200
  • Any other relevant information: None

Additional context

Metadata

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions