import os
import torch
import torch.distributed as dist
+from typing import Sequence
from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
from torch.utils._python_dispatch import return_and_correct_aliasing
from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
@@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
    )
    return m

+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    """
+    Shard a full tensor according to the indicated placements and return the
+    resulting DTensor; this simplifies both colwise_shard and rowwise_shard.
+    The goal is to eventually move this helper into DTensor as a static method,
+    e.g.
+        dtensor = DTensor.shard(full_tensor, device_mesh, placements)
+    """
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    # Shape of this rank's shard and where it starts in the full tensor.
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    # One slice per dimension, cutting the local shard out of the full tensor.
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    """
    Shard linear layer of the model in column-wise fashion
    """
    # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
    orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
    # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
    # Replace parameter in module
    m.linear.weight = torch.nn.Parameter(
        dtensor, requires_grad=False
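
For intuition, here is a minimal, CPU-only sketch (not part of this patch) of what the new shard helper derives per rank. It assumes a 1-D mesh launched with torchrun and a 128x64 weight; the mesh size, tensor shape, and script structure are illustrative only, while compute_local_shape_and_global_offset is the same private utility the helper itself imports.

import os

import torch
from torch.distributed import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset


def main():
    world_size = int(os.environ["WORLD_SIZE"])     # set by torchrun
    mesh = init_device_mesh("cpu", (world_size,))  # 1-D mesh over the gloo backend

    full = torch.arange(128 * 64, dtype=torch.float32).reshape(128, 64)

    # What shard() computes internally for a column-wise shard ([Shard(0)]):
    # with 2 ranks, shape == (64, 64) on both ranks, offset == (0, 0) on
    # rank 0 and (64, 0) on rank 1.
    shape, offset = compute_local_shape_and_global_offset(full.shape, mesh, [Shard(0)])

    # The per-rank slice the hand-written code in colwise_shard used to build.
    rank = mesh.get_local_rank()
    n_local_rows = full.size(0) // mesh.size()
    manual = full[rank * n_local_rows : (rank + 1) * n_local_rows, :]

    # The generic slices derived from shape/offset select the same block.
    derived = full[tuple(slice(o, o + s) for s, o in zip(shape, offset))]
    assert torch.equal(manual, derived)
    print(f"rank {rank}: local shape {tuple(derived.shape)}, global offset {offset}")


if __name__ == "__main__":
    main()
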
@@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    Shard linear layer of the model in row-wise fashion
    """
    # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
    orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
    # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
    # Replace parameter in module
    m.linear.weight = torch.nn.Parameter(
        dtensor, requires_grad=False
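
And a short usage sketch, again not part of the patch, of how the sharding helpers are typically wired up under torchrun. The Toy module and the import path are placeholders; in the actual tutorial the weight would first be converted by quantize(m) before being wrapped into a DTensor.

import os

import torch
import torch.nn as nn
from torch.distributed import init_device_mesh

# Placeholder import path: wherever colwise_shard / rowwise_shard are defined.
from tensor_parallel import colwise_shard, rowwise_shard


class Toy(nn.Module):
    """Minimal stand-in exposing the .linear attribute the shard helpers expect."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 128, bias=False)

    def forward(self, x):
        return self.linear(x)


def main():
    world_size = int(os.environ["WORLD_SIZE"])     # set by torchrun
    mesh = init_device_mesh("cpu", (world_size,))

    # In the tutorial, quantize(m) would run first; here plain weights are sharded.
    col = colwise_shard(Toy(), mesh)               # weight wrapped as DTensor with [Shard(0)]
    row = rowwise_shard(Toy(), mesh)               # weight wrapped as DTensor with [Shard(1)]
    print(type(col.linear.weight), type(row.linear.weight))


if __name__ == "__main__":
    main()
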