🐛 [Bug] Shape mismatch when using repeat #3972

@zhaoyuanh

Bug Description

from typing import Optional, Tuple
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch_tensorrt


class CosmosLearnablePositionalEmbed(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        max_size: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
        self.patch_size = patch_size
        self.eps = eps

        self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
        self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
        self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
        
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]

        # Workaround (commented out below): expand() instead of repeat() is torch_tensorrt compatible.
        # expand() creates a view without copying data, better for dynamic shapes; the repeat() calls below reproduce the bug.
        # emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(batch_size, -1, pe_size[1], pe_size[2], -1)
        # emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(batch_size, pe_size[0], -1, pe_size[2], -1)
        # emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(batch_size, pe_size[0], pe_size[1], -1, -1)
        emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
        emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
        emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
        emb = emb_t + emb_h + emb_w
        emb = emb.flatten(1, 3)

        norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
        # norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
        # Use torch operations instead of np.sqrt to support dynamic shapes in torch.export
        # Compute the scale factor: sqrt(norm.numel() / emb.numel())
        alpha = (norm.numel() / emb.numel()) ** 0.5
        norm = torch.add(self.eps, norm, alpha=alpha)
        out = (emb / norm).type_as(hidden_states)
        return out


def export_attention(model, hidden_states):
    with torch.no_grad():
        # Only mark sequence length as dynamic, like run_llm.py does
        # Don't mark batch dimension as dynamic to avoid constraint violations
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use tuple format like export_llm - only mark the frame dimension (dim 2) as dynamic
        ep = torch.export.export(
            model,
            args=(hidden_states,),
            kwargs={},
            dynamic_shapes=({2: seq_len},), 
            strict=False,
        )

    return ep


def compile_torchtrt(model, hidden_states, min_block_size, debug):
    ep = export_attention(model, hidden_states)
    # Set precision-specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}

    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )

    return trt_model


if __name__ == "__main__":
    precision = "BF16"
    min_block_size = 1
    batch_size = 1
    seq_len = 28160
    num_attention_heads = 32
    attention_head_dim = 128
    enable_pytorch_run = True
    debug = False
    device = "cuda"

    hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = CosmosLearnablePositionalEmbed(
            hidden_size=hidden_size,
            max_size=(128, 240, 240),
            patch_size=(1, 2, 2),
        ).to(device)

        # Convert model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare input for benchmarking or evaluation
        hidden_states = torch.randn(
            1, 17, 16, 88, 160, dtype=input_dtype
        ).to(device)

        # PyTorch reference run
        pyt_output = model(hidden_states)
        print("PyTorch output shape:", pyt_output.shape)
        print("Pytorch output:", pyt_output.flatten())

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, hidden_states, min_block_size, debug)
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)

        trt_output = trt_model(hidden_states)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output:", trt_output.flatten())
    
    # Verify results match
    diff = (pyt_output - trt_output).abs().max().item()
    print(f"Max difference between PyTorch and TRT: {diff}")

    # Check if results are close enough
    tolerance = 0.01
    if diff < tolerance:
        print(f"✅ Results match! (difference: {diff} < {tolerance})")
    else:
        print(f"⚠️  Results differ! (difference: {diff} >= {tolerance})")

Here is the error message.

Traceback (most recent call last):
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 214, in <module>
    trt_model = compile_torchtrt(model, hidden_states, min_block_size, debug)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 153, in compile_torchtrt
    trt_model = torch_tensorrt.dynamo.compile(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 782, in compile
    trt_gm = compile_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 1028, in compile_module
    trt_module = convert_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 145, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 78, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 721, in run
    self._construct_trt_network_def()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 433, in _construct_trt_network_def
    super().run()
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 790, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 897, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 677, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1236, in aten_ops_expand
    return impl.slice.expand(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 245, in expand
    assert len(input_t.shape) == shape_rank
ValueError: __len__() should return >= 0

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_2, [1, 1, 44, 80, 1, 1, %sym_size_int_6, 1, 1, 4096]), kwargs = {})
Original traceback:
File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 109, in forward
    emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

To Reproduce

Steps to reproduce the behavior:

  1. Run the Python script above.

Expected behavior

Compilation succeeds and the Torch-TRT output matches the PyTorch output.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
  • PyTorch Version (e.g. 1.0): 2.9.0
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
  • Are you using local sources or building from archives: No
  • Python version: 3.10
  • CUDA version: 12.9
  • GPU models and configuration: Nvidia B200
  • Any other relevant information: None

Additional context

The issue can be resolved by replacing repeat with expand.
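
For reference, here is the workaround as a minimal sketch; it is the commented-out variant already present in the reproducer's forward() above. expand() returns a broadcast view instead of tiling copies the way repeat() does, which avoids the rank-10 aten.expand produced by the repeat decomposition (the shape shown in the traceback).

# Inside CosmosLearnablePositionalEmbed.forward(), replace the repeat() calls with:
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(batch_size, -1, pe_size[1], pe_size[2], -1)
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(batch_size, pe_size[0], -1, pe_size[2], -1)
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(batch_size, pe_size[0], pe_size[1], -1, -1)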
