Skip to content

Commit 96c677b

Browse files
committed
fix to work lienar/cosine lr scheduler closes #1602 ref #1393
1 parent be078bd commit 96c677b

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

library/train_util.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4707,6 +4707,15 @@ def wrap_check_needless_num_warmup_steps(return_vals):
47074707
**lr_scheduler_kwargs,
47084708
)
47094709

4710+
# these schedulers do not require `num_decay_steps`
4711+
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
4712+
return schedule_func(
4713+
optimizer,
4714+
num_warmup_steps=num_warmup_steps,
4715+
num_training_steps=num_training_steps,
4716+
**lr_scheduler_kwargs,
4717+
)
4718+
47104719
# All other schedulers require `num_decay_steps`
47114720
if num_decay_steps is None:
47124721
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
@@ -5837,14 +5846,9 @@ def sample_image_inference(
58375846
wandb_tracker = accelerator.get_tracker("wandb")
58385847

58395848
import wandb
5849+
58405850
# not to commit images to avoid inconsistency between training and logging steps
5841-
wandb_tracker.log(
5842-
{f"sample_{i}": wandb.Image(
5843-
image,
5844-
caption=prompt # positive prompt as a caption
5845-
)},
5846-
commit=False
5847-
)
5851+
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
58485852

58495853

58505854
# endregion

0 commit comments

Comments
 (0)