Skip to content

Commit dc0441a

Browse files
authored
fix: Epoch checkpointing index off by 1 (#2911)
1 parent b942406 commit dc0441a

File tree

4 files changed

+11
-8
lines changed

4 files changed

+11
-8
lines changed

recipes/full_finetune_distributed.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -915,8 +915,9 @@ def validate(self) -> dict[str, float]:
915915
return log_dict
916916

917917
def save_checkpoint(self, *, epoch: int, full_tensors: bool):
918+
training_progress_epoch = epoch
918919
if self.global_step % self._steps_per_epoch == 0:
919-
epoch += 1
920+
training_progress_epoch += 1
920921

921922
self._checkpoint_client.save_checkpoint(
922923
model=self._model,
@@ -927,7 +928,7 @@ def save_checkpoint(self, *, epoch: int, full_tensors: bool):
927928
),
928929
training_progress=TrainingProgress(
929930
seed=self.seed,
930-
epochs_run=epoch,
931+
epochs_run=training_progress_epoch,
931932
total_epochs=self.total_epochs,
932933
max_steps_per_epoch=self.max_steps_per_epoch,
933934
steps_run=self.global_step,

recipes/full_finetune_single_device.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -526,14 +526,16 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool) -> None:
526526
correctly creating the checkpoint dict and passing to the checkpointer.
527527
"""
528528
# Since we might save at an epoch boundary, we need to increment the epoch counter
529+
training_progress_epoch = epoch
529530
if step % self._steps_per_epoch == 0:
530-
epoch += 1
531+
training_progress_epoch += 1
532+
531533
self._checkpoint_client.save_checkpoint(
532534
model=self._model,
533535
optimizer=self.optimizer,
534536
training_progress=TrainingProgress(
535537
seed=self.seed,
536-
epochs_run=epoch,
538+
epochs_run=training_progress_epoch,
537539
total_epochs=self.total_epochs,
538540
max_steps_per_epoch=self.max_steps_per_epoch,
539541
dataloader_state_dict=self._dataloader.state_dict(),

tests/recipes/test_full_finetune_distributed.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,7 @@ def test_training_state_on_resume_from_distributed_checkpoint_single_rank(
397397

398398
resumed_log_dir = (tmpdir / "resumed/").mkdir()
399399
resumed_log_file = gen_log_file_name(resumed_log_dir)
400-
shutil.rmtree((tmpdir / "epoch_2"))
400+
shutil.rmtree((tmpdir / "epoch_1"))
401401

402402
# Resume training
403403
cmd_2 = f"""
@@ -504,7 +504,7 @@ def test_training_state_on_resume_from_distributed_checkpoint_multi_rank(
504504

505505
resumed_log_dir = (tmpdir / "resumed/").mkdir()
506506
resumed_log_file = gen_log_file_name(resumed_log_dir)
507-
shutil.rmtree((tmpdir / "epoch_2"))
507+
shutil.rmtree((tmpdir / "epoch_1"))
508508

509509
# Resume training
510510
cmd_2 = f"""

tests/recipes/test_full_finetune_single_device.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -204,8 +204,8 @@ def test_training_state_on_resume(self, tmpdir, use_steps, monkeypatch):
204204
final_ckpt_dir = "step_4"
205205
prev_ckpt_dir = "step_2"
206206
else:
207-
final_ckpt_dir = "epoch_2"
208-
prev_ckpt_dir = "epoch_1"
207+
final_ckpt_dir = "epoch_1"
208+
prev_ckpt_dir = "epoch_0"
209209
cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
210210

211211
monkeypatch.setattr(sys, "argv", cmd_1)

0 commit comments

Comments
 (0)