Commit 0043ace

kwen2501 authored and weifengpy committed
[Distributed] Improve sharding example (pytorch#937)
* [Distributed] Improve sharding example
* Add comment
1 parent fc6c393 commit 0043ace

File tree

1 file changed: +30 -11 lines changed

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 30 additions & 11 deletions
@@ -1,8 +1,9 @@
 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults

@@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
     )
     return m

+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    """
+    Add a shard function to simplify both colwise_shard and rowwise_shard. The
+    shard function accepts a full tensor, and returns a DTensor based on
+    indicated placements. Goal is to move the shard function as a static method
+    of DTensor, e.g.
+        dtensor = DTensor.shard(full_tensor, device_mesh, placement)
+    """
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
@@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
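
For reference, here is a minimal sketch (not part of the commit) of how the new shard() helper could be exercised and checked against the manual per-rank slicing it replaces. It assumes a launch via torchrun on a machine with at least 2 GPUs, and that tensor_parallel.py is importable from the working directory (it lives in tutorials/developer_api_guide/); the file name check_shard.py and the tensor shape are illustrative only.

# Minimal sketch (not part of this commit): check that the new shard() helper
# matches the manual per-rank slicing that colwise_shard used previously.
# Hypothetical file name: check_shard.py; launch with e.g.
#   torchrun --nproc_per_node=2 check_shard.py
# Assumes >= 2 GPUs and that tensor_parallel.py is importable.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

from tensor_parallel import shard  # helper added in this commit


def main():
    # Bind this process to its GPU and build a 1-D device mesh over all ranks.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,))

    # Full (unsharded) weight; seeded so every rank holds the same values.
    torch.manual_seed(0)
    full = torch.randn(16, 8, device="cuda")

    # New path: the helper computes this rank's slice and wraps it in a DTensor.
    dtensor = shard(full, mesh, [Shard(0)])

    # Old path from the previous version of the tutorial: slice by rank manually.
    n_local_rows = full.size(0) // mesh.size()
    rank = mesh.get_local_rank()
    manual = full[rank * n_local_rows : (rank + 1) * n_local_rows, :]

    assert torch.equal(dtensor.to_local(), manual)


if __name__ == "__main__":
    main()

Because shard() derives the local slice from compute_local_shape_and_global_offset, the same call works for both Shard(0) and Shard(1) placements, which is what lets colwise_shard and rowwise_shard each collapse to a single shard(...) call in this commit.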
