Skip to content

Commit 2668c67

Browse files
committed
backend='eager'
1 parent 859ddce commit 2668c67

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def _(func, types, args, kwargs):
8080
args[1],
8181
None
8282
)
83-
print("input tensor shape:", input_tensor.shape)
84-
print("weight tensor shape:", weight_tensor.shape)
83+
print("mm input tensor shape:", input_tensor.shape)
84+
print("mm weight tensor shape:", weight_tensor.shape)
8585
weight_tensor = weight_tensor.dequantize()
8686
return aten.mm(input_tensor, weight_tensor)
8787

@@ -188,10 +188,10 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
188188
# [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
189189
# 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
190190
# [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
191-
c_up = torch.compile(d_up)
191+
c_up = torch.compile(d_up, backend="eager")
192192
y_up = c_up(input_dtensor)
193193
print("y_up:", y_up.shape)
194-
c_dn = torch.compile(d_dn)
194+
c_dn = torch.compile(d_dn, backend="eager")
195195
y_dn = c_dn(y_up)
196196
print("y_dn:", y_dn.shape)
197197
print("compiled result:", y_dn)

0 commit comments

Comments
 (0)