Skip to content

Commit 5a4ca92

Browse files
authored
Adding torch accelerator and requirements file to FSDP2 example (#1375)
Signed-off-by: dggaytan <[email protected]>
1 parent e9a4e75 commit 5a4ca92

File tree

5 files changed

+46
-8
lines changed

5 files changed

+46
-8
lines changed

distributed/FSDP2/README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
## FSDP2
22
To run FSDP2 on transformer model:
3+
34
```
45
cd distributed/FSDP2
5-
torchrun --nproc_per_node 2 train.py
6+
pip install -r requirements.txt
7+
torchrun --nproc_per_node 2 example.py
68
```
79
* For 1st time, it creates a "checkpoints" folder and saves state dicts there
810
* For 2nd time, it loads from previous checkpoints
911

1012
To enable explicit prefetching
1113
```
12-
torchrun --nproc_per_node 2 train.py --explicit-prefetch
14+
torchrun --nproc_per_node 2 example.py --explicit-prefetch
1315
```
1416

1517
To enable mixed precision
1618
```
17-
torchrun --nproc_per_node 2 train.py --mixed-precision
19+
torchrun --nproc_per_node 2 example.py --mixed-precision
1820
```
1921

2022
To showcase DCP API
2123
```
22-
torchrun --nproc_per_node 2 train.py --dcp-api
24+
torchrun --nproc_per_node 2 example.py --dcp-api
2325
```
2426

2527
## Ensure you are running a recent version of PyTorch:

distributed/FSDP2/train.py renamed to distributed/FSDP2/example.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
88
from utils import inspect_mixed_precision, inspect_model
99

10+
def verify_min_gpu_count(min_gpus: int = 2) -> bool:
11+
""" verification that we have at least 2 gpus to run dist examples """
12+
has_gpu = torch.accelerator.is_available()
13+
gpu_count = torch.accelerator.device_count()
14+
return has_gpu and gpu_count >= min_gpus
1015

1116
def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
1217
for i, layer in enumerate(model.layers):
@@ -29,10 +34,23 @@ def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
2934

3035

3136
def main(args):
37+
_min_gpu_count = 2
38+
if not verify_min_gpu_count(min_gpus=_min_gpu_count):
39+
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
40+
exit()
3241
rank = int(os.environ["LOCAL_RANK"])
33-
device = torch.device(f"cuda:{rank}")
34-
torch.cuda.set_device(device)
35-
torch.distributed.init_process_group(backend="nccl", device_id=device)
42+
if torch.accelerator.is_available():
43+
device_type = torch.accelerator.current_accelerator()
44+
device = torch.device(f"{device_type}:{rank}")
45+
torch.accelerator.device_index(rank)
46+
print(f"Running on rank {rank} on device {device}")
47+
else:
48+
device = torch.device("cpu")
49+
print(f"Running on device {device}")
50+
51+
backend = torch.distributed.get_default_backend_for_device(device)
52+
torch.distributed.init_process_group(backend=backend, device_id=device)
53+
3654
torch.manual_seed(0)
3755
vocab_size = 1024
3856
batch_size = 32
@@ -64,7 +82,7 @@ def main(args):
6482

6583
checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
6684
if checkpointer.last_training_time is None:
67-
model.to_empty(device="cuda")
85+
model.to_empty(device=device)
6886
model.reset_parameters()
6987
else:
7088
checkpointer.load_model(model)
@@ -96,4 +114,5 @@ def main(args):
96114
parser.add_argument("--mixed-precision", action="store_true", default=False)
97115
parser.add_argument("--dcp-api", action="store_true", default=False)
98116
args = parser.parse_args()
117+
99118
main(args)

distributed/FSDP2/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch>=2.7
2+
numpy

distributed/FSDP2/run_example.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# /bin/bash
2+
# bash run_example.sh {file_to_run.py} {num_gpus}
3+
# where file_to_run = example to run. Default = 'example.py'
4+
# num_gpus = num local gpus to use (must be at least 2). Default = 4
5+
6+
# samples to run include:
7+
# example.py
8+
9+
echo "Launching ${1:-example.py} with ${2:-4} gpus"
10+
torchrun --nnodes=1 --nproc_per_node=${2:-4} ${1:-example.py}
11+

run_distributed_examples.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ function distributed_tensor_parallelism() {
5050
uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
5151
}
5252

53+
function distributed_FSDP2() {
54+
uv run bash run_example.sh example.py || error "FSDP2 example failed"
55+
}
56+
5357
function distributed_ddp() {
5458
uv run bash run_example.sh example.py || error "ddp example failed"
5559
}

0 commit comments

Comments
 (0)