Bug Description
```python
from typing import Optional, Tuple
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch_tensorrt


class CosmosRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        max_size: Tuple[int, int, int] = (128, 240, 240),
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        base_fps: int = 24,
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
    ) -> None:
        super().__init__()
        self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
        self.patch_size = patch_size
        self.base_fps = base_fps
        self.dim_h = hidden_size // 6 * 2
        self.dim_w = hidden_size // 6 * 2
        self.dim_t = hidden_size - self.dim_h - self.dim_w
        self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
        self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
        self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))

    def forward(
        self,
        hidden_states: torch.Tensor,
        fps: Optional[int] = None,
        num_ranks: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
        if num_ranks is not None:
            pe_size[0] = pe_size[0] * num_ranks
        device = hidden_states.device

        h_theta = 10000.0 * self.h_ntk_factor
        w_theta = 10000.0 * self.w_ntk_factor
        t_theta = 10000.0 * self.t_ntk_factor

        seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
        dim_h_range = (
            torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
        )
        dim_w_range = (
            torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
        )
        dim_t_range = (
            torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
        )
        h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
        w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
        temporal_freqs = 1.0 / (t_theta**dim_t_range)

        # Use expand() instead of repeat() for torch_tensorrt compatibility
        emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].expand(pe_size[0], -1, pe_size[2], -1)
        emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].expand(pe_size[0], pe_size[1], -1, -1)

        # Apply sequence scaling in temporal dimension
        if fps is None:
            # Images
            emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
        else:
            # Videos
            emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
        emb_t = emb_t[:, None, None, :].expand(-1, pe_size[1], pe_size[2], -1)

        freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)

        if num_ranks is not None:
            cos = cos.view(num_ranks, cos.shape[0] // num_ranks, *cos.shape[1:])
            cos = cos[0]
            sin = sin.view(num_ranks, sin.shape[0] // num_ranks, *sin.shape[1:])
            sin = sin[0]
        return cos, sin


def export_attention(model, hidden_states, fps, num_ranks):
    with torch.no_grad():
        # Only mark the sequence length (frame dimension) as dynamic, like run_llm.py does.
        # Don't mark the batch dimension as dynamic to avoid constraint violations.
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use tuple format like export_llm - only mark the frame dimension (dim 2) as dynamic.
        ep = torch.export.export(
            model,
            args=(hidden_states, fps, num_ranks),
            kwargs={},
            dynamic_shapes=({2: seq_len}, None, None),
            strict=False,
        )
    return ep


def compile_torchtrt(model, hidden_states, fps, num_ranks, min_block_size, debug):
    ep = export_attention(model, hidden_states, fps, num_ranks)
    # Set precision-specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    enabled_precisions = {torch.bfloat16}
    with torch_tensorrt.logging.debug() if debug else nullcontext():
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
            inputs=[hidden_states, fps, num_ranks],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
            use_fp32_acc=use_fp32_acc,
            disable_tf32=True,
            use_python_runtime=True,
            debug=debug,
            offload_module_to_cpu=False,
            min_block_size=min_block_size,
        )
    return trt_model


if __name__ == "__main__":
    min_block_size = 1
    attention_head_dim = 128
    enable_pytorch_run = True
    debug = False
    device = "cuda"
    # hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim,
            max_size=(128, 240, 240),
            patch_size=(1, 2, 2),
            rope_scale=(2.0, 1.0, 1.0),
        ).to(device)

        # Convert model to the appropriate precision
        model = model.to(torch.bfloat16)
        input_dtype = torch.bfloat16

        # Prepare input for benchmarking or evaluation
        hidden_states = torch.randn(1, 17, 8, 88, 160, dtype=input_dtype).to(device)
        fps = 30
        num_ranks = 2

        # PyTorch reference run
        pyt_output_cos, pyt_output_sin = model(hidden_states, fps, num_ranks)
        print("PyTorch output shape:", pyt_output_cos.shape, pyt_output_sin.shape)
        print("PyTorch output:", pyt_output_cos.flatten(), pyt_output_sin.flatten())

        # Compile the model with Torch-TensorRT
        trt_model = compile_torchtrt(model, hidden_states, fps, num_ranks, min_block_size, debug)
        # trt_model = torch.compile(
        #     model,
        #     backend="torch_tensorrt",
        #     options={
        #         "enabled_precisions": {input_dtype},
        #         "use_python_runtime": True,
        #         "min_block_size": min_block_size,
        #     },
        #     dynamic=None,
        # )
        trt_model = trt_model.to(device)
        trt_output_cos, trt_output_sin = trt_model(hidden_states, fps, num_ranks)
        print("TensorRT output shape:", trt_output_cos.shape, trt_output_sin.shape)
        print("TensorRT output:", trt_output_cos.flatten(), trt_output_sin.flatten())

        # Verify results match
        diff_cos = (pyt_output_cos - trt_output_cos).abs().max().item()
        diff_sin = (pyt_output_sin - trt_output_sin).abs().max().item()
        print(f"Max difference between PyTorch and TRT: {diff_cos}, {diff_sin}")

        # Check if results are close enough
        tolerance = 0.01
        if diff_cos < tolerance:
            print(f"✅ Results match! (difference: {diff_cos} < {tolerance})")
        else:
            print(f"⚠️ Results differ! (difference: {diff_cos} >= {tolerance})")
        if diff_sin < tolerance:
            print(f"✅ Results match! (difference: {diff_sin} < {tolerance})")
        else:
            print(f"⚠️ Results differ! (difference: {diff_sin} >= {tolerance})")
```
Here is the result:

```
⚠️ Results differ! (difference: 0.11663024872541428 >= 0.01)
⚠️ Results differ! (difference: 0.1167231947183609 >= 0.01)
```
When I disable dynamic shapes, e.g. by commenting out `dynamic_shapes=({2: seq_len}, None, None)`, the TensorRT and PyTorch results do match (see the sketch below).
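For clarity, this is the static-shape export I mean (a sketch; `export_attention_static` is just an illustrative name, the body is the `export_attention` call above with `dynamic_shapes` removed):

```python
def export_attention_static(model, hidden_states, fps, num_ranks):
    # Same export as export_attention above, but with every input
    # dimension treated as static (no dynamic_shapes argument).
    with torch.no_grad():
        ep = torch.export.export(
            model,
            args=(hidden_states, fps, num_ranks),
            kwargs={},
            strict=False,
        )
    return ep
```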
To Reproduce
Steps to reproduce the behavior:
- Run the Python script above.
Expected behavior
The Torch-TRT output matches the Torch output.
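Concretely (a sketch, not part of the original script): with the 0.01 tolerance the script uses, a check along these lines should pass; the variables refer to the outputs computed in the repro above, and the rtol/atol split is just my choice.

```python
# Hypothetical acceptance check using the script's 0.01 tolerance.
# pyt_output_* and trt_output_* are the tensors computed in the repro above.
torch.testing.assert_close(trt_output_cos, pyt_output_cos, rtol=0.0, atol=0.01)
torch.testing.assert_close(trt_output_sin, pyt_output_sin, rtol=0.0, atol=0.01)
```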
Environment
Build information about Torch-TensorRT can be found by turning on debug messages (see the snippet after this list).
- Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
- PyTorch Version (e.g. 1.0): 2.9.0
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
- Are you using local sources or building from archives: No
- Python version: 3.10
- CUDA version: 12.9
- GPU models and configuration: Nvidia B200
- Any other relevant information: None
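For reference, the debug messages come from rerunning the compile step of the script above with its `debug` flag set (a sketch; `compile_torchtrt` is the helper defined in the repro, which wraps the compile call in `torch_tensorrt.logging.debug()` when `debug=True`):

```python
# Rerun the compile step from the repro with debug logging enabled;
# this dumps the Torch-TensorRT build information mentioned above.
trt_model = compile_torchtrt(
    model, hidden_states, fps, num_ranks, min_block_size, debug=True
)
```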
Additional context