## Bug Description

`torch_tensorrt.dynamo.compile` fails in the `aten.expand` converter with `ValueError: __len__() should return >= 0` when compiling a module that calls `Tensor.repeat()` on a tensor with a dynamic dimension. Minimal reproducer:
```python
from typing import Tuple
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch_tensorrt


class CosmosLearnablePositionalEmbed(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        max_size: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
        self.patch_size = patch_size
        self.eps = eps
        self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
        self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
        self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        pe_size = [
            num_frames // self.patch_size[0],
            height // self.patch_size[1],
            width // self.patch_size[2],
        ]
        # Using expand() instead of repeat() avoids the bug - expand() creates a
        # view without copying data, which torch_tensorrt handles for dynamic shapes:
        # emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(batch_size, -1, pe_size[1], pe_size[2], -1)
        # emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(batch_size, pe_size[0], -1, pe_size[2], -1)
        # emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(batch_size, pe_size[0], pe_size[1], -1, -1)
        # The repeat() calls below are what trigger the converter failure.
        emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
        emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
        emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
        emb = emb_t + emb_h + emb_w
        emb = emb.flatten(1, 3)
        norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
        # Original: norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
        # Compute the scale factor sqrt(norm.numel() / emb.numel()) with Python
        # ops instead of np.sqrt so torch.export supports dynamic shapes.
        alpha = (norm.numel() / emb.numel()) ** 0.5
        norm = torch.add(self.eps, norm, alpha=alpha)
        out = (emb / norm).type_as(hidden_states)
        return out


def export_attention(model, hidden_states):
    with torch.no_grad():
        # Only mark the frame dimension (dim 2) as dynamic, like run_llm.py does;
        # the batch dimension stays static to avoid constraint violations.
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use the tuple format for dynamic_shapes, like export_llm does.
        ep = torch.export.export(
            model,
            args=(hidden_states,),
            kwargs={},
            dynamic_shapes=({2: seq_len},),
            strict=False,
        )
    return ep


def compile_torchtrt(model, hidden_states, min_block_size, debug):
    ep = export_attention(model, hidden_states)
    # Set precision-specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}

    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )
    return trt_model


if __name__ == "__main__":
    precision = "BF16"
    min_block_size = 1
    batch_size = 1
    seq_len = 28160
    num_attention_heads = 32
    attention_head_dim = 128
    enable_pytorch_run = True
    debug = False
    device = "cuda"
    hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = CosmosLearnablePositionalEmbed(
            hidden_size=hidden_size,
            max_size=(128, 240, 240),
            patch_size=(1, 2, 2),
        ).to(device)
        # Convert model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare input for benchmarking or evaluation
        hidden_states = torch.randn(1, 17, 16, 88, 160, dtype=input_dtype).to(device)

        # PyTorch reference run
        pyt_output = model(hidden_states)
        print("PyTorch output shape:", pyt_output.shape)
        print("PyTorch output:", pyt_output.flatten())

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, hidden_states, min_block_size, debug)
        # Alternative path via torch.compile with the torch_tensorrt backend:
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)
        trt_output = trt_model(hidden_states)
        print("TensorRT output shape:", trt_output.shape)
        print("TensorRT output:", trt_output.flatten())

        # Verify results match
        diff = (pyt_output - trt_output).abs().max().item()
        print(f"Max difference between PyTorch and TRT: {diff}")

        # Check if results are close enough
        tolerance = 0.01
        if diff < tolerance:
            print(f"✅ Results match! (difference: {diff} < {tolerance})")
        else:
            print(f"⚠️ Results differ! (difference: {diff} >= {tolerance})")
```
Here is the error message:

```
Traceback (most recent call last):
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 214, in <module>
    trt_model = compile_torchtrt(model, hidden_states, min_block_size, debug)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 153, in compile_torchtrt
    trt_model = torch_tensorrt.dynamo.compile(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 782, in compile
    trt_gm = compile_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/_compiler.py", line 1028, in compile_module
    trt_module = convert_module(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 145, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 78, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 721, in run
    self._construct_trt_network_def()
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 433, in _construct_trt_network_def
    super().run()
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 790, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/home/zhaoyuanh/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 897, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 677, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1236, in aten_ops_expand
    return impl.slice.expand(
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 245, in expand
    assert len(input_t.shape) == shape_rank
ValueError: __len__() should return >= 0

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_2, [1, 1, 44, 80, 1, 1, %sym_size_int_6, 1, 1, 4096]), kwargs = {})

Original traceback:
  File "/home/scratch.zhaoyuanh_coreai/torch-trt/tools/llm/minimal_reproducer.py", line 109, in forward
    emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
```
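Note that the failing `%expand` node is not written anywhere in the model: it comes from the decomposition of `aten.repeat` into `unsqueeze` + `expand` + `reshape` during lowering, and with a dynamic input dimension the expand target shape carries a `SymInt` (the `%sym_size_int_6` above), which the converter's `assert len(input_t.shape) == shape_rank` then trips over. A minimal, standalone sketch of that decomposition (the `RepeatOnly` module name and shapes are hypothetical, chosen only to mirror the dynamic-dim setup):

```python
import torch
from torch._decomp import get_decompositions


class RepeatOnly(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # repeat() along a static dim while dim 0 stays dynamic
        return x.repeat(2, 1)


s = torch.export.Dim("s", min=1, max=16)
ep = torch.export.export(
    RepeatOnly(),
    (torch.randn(4, 8),),
    dynamic_shapes=({0: s},),
    strict=False,
)
# Lower aten.repeat explicitly; the resulting graph should contain an
# aten.expand whose target shape mixes constants with a symbolic size,
# analogous to the failing node in the traceback above.
ep = ep.run_decompositions(get_decompositions([torch.ops.aten.repeat]))
print(ep.graph_module.graph)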
## To Reproduce

Steps to reproduce the behavior:

- Run the Python script above (it appears as `tools/llm/minimal_reproducer.py` in the traceback).
## Expected behavior

Compilation succeeds and the Torch-TensorRT output matches the PyTorch output.
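As a side note on the final check: the script compares the max absolute difference against a 0.01 tolerance; an equivalent check with `torch.testing.assert_close` might look like this (the rtol/atol values are illustrative for bf16, not prescriptive):

```python
import torch

# pyt_output / trt_output are the tensors produced by the reproducer above;
# the script's manual check uses max |pyt_output - trt_output| < 0.01 instead.
torch.testing.assert_close(trt_output, pyt_output, rtol=1e-2, atol=1e-2)
```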
## Environment

Build information about Torch-TensorRT can be found by turning on debug messages.

- Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
- PyTorch Version (e.g. 1.0): 2.9.0
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): pip
- Build command you used (if compiling from source): `PYTHON_ONLY=1 pip install -e .`
- Are you using local sources or building from archives: No
- Python version: 3.10
- CUDA version: 12.9
- GPU models and configuration: NVIDIA B200
- Any other relevant information: None
## Additional context

The issue can be resolved by replacing `repeat` with `expand`, as in the commented-out lines in the reproducer; the change is sketched below.
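Concretely, these are the `expand()` variants already present as comments in the reproducer's `forward`; swapping them in for the three `repeat()` calls lets compilation pass:

```python
# Inside CosmosLearnablePositionalEmbed.forward, replacing the repeat() calls.
# expand() broadcasts a view (no data copy) and accepts -1 for unchanged dims,
# which Torch-TensorRT's expand converter handles with dynamic shapes.
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(batch_size, -1, pe_size[1], pe_size[2], -1)
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(batch_size, pe_size[0], -1, pe_size[2], -1)
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(batch_size, pe_size[0], pe_size[1], -1, -1)
```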