DDPShardedStrategy with gradient accumulation #13426
Unanswered
SerezD asked this question in DDP / multi-GPU / multi-node
Replies: 1 comment 4 replies
-
I haven't been able to reproduce this error. Which versions of Lightning and FairScale are you using? Below is the code I used to try to reproduce it:

import os
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies import DDPShardedStrategy
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator='gpu',
        devices=2,
        strategy=DDPShardedStrategy(),
        accumulate_grad_batches=12,
        num_sanity_val_steps=0,
        max_epochs=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()
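If it helps, here is a minimal snippet for reporting the versions in question; it only relies on the standard __version__ attributes of the installed packages:

import fairscale
import pytorch_lightning as pl
import torch

# Print the package versions relevant to this thread.
print("torch:", torch.__version__)
print("pytorch_lightning:", pl.__version__)
print("fairscale:", fairscale.__version__)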
-
I need to use both DDPShardedStrategy and accumulate_grad_batches > 1, but this combination outputs a warning during training.
The question is: how can I remove the warning (by using a no_sync() context)?
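For context, here is a minimal sketch of the generic no_sync() gradient-accumulation pattern in plain PyTorch rather than Lightning. It assumes a DDP-style wrapper that exposes a no_sync() context manager (both torch.nn.parallel.DistributedDataParallel and FairScale's ShardedDataParallel do): synchronization is skipped for every micro-batch except the last one of each accumulation window. The train_with_accumulation helper and its argument names are illustrative, not an existing API.

import contextlib

def train_with_accumulation(ddp_model, optimizer, dataloader, accumulate=12):
    # Accumulate gradients locally and only reduce them across ranks
    # on the last micro-batch of each accumulation window.
    optimizer.zero_grad()
    for i, batch in enumerate(dataloader):
        is_sync_step = (i + 1) % accumulate == 0
        # no_sync() suppresses the inter-process gradient reduction;
        # on the sync step a null context lets the reduction run as usual.
        ctx = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()
        with ctx:
            loss = ddp_model(batch).sum() / accumulate
            loss.backward()
        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()

In Lightning itself, accumulation is driven by accumulate_grad_batches, so whether no_sync() is applied under the hood depends on the strategy implementation in the installed Lightning/FairScale versions, which is presumably why the versions were asked about above.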