 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
@@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith("%"):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
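For reference, the new converter is small enough to restate standalone; the asserts below illustrate its intended behavior (whole numbers become absolute step counts, floats below 1 and percentages become ratios). This is a sketch for illustration, not part of the patch:

```python
import argparse

# Standalone restatement of the converter added above (in the actual patch it is
# nested inside add_optimizer_arguments).
def int_or_float(value):
    if value.endswith("%"):
        try:
            return float(value[:-1]) / 100.0
        except ValueError:
            raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
    try:
        float_value = float(value)
        if float_value >= 1:
            return int(value)
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")

assert int_or_float("250") == 250    # whole number -> absolute step count
assert int_or_float("0.05") == 0.05  # float below 1 -> ratio of total training steps
assert int_or_float("5%") == 0.05    # percentage -> ratio of total training steps
```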
@@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
+        default=0,
+        help="Number of warmup steps for the lr scheduler (default 0), or a float below 1 as a ratio of total training steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+    )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Number of decay steps for the lr scheduler (default 0), or a float below 1 as a ratio of total training steps",
     )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
         + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
     )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for the inverse sqrt scheduler; defaults to `num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate, used by the cosine-with-min-lr and warmup-stable-decay schedulers",
+    )


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
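As a rough illustration of `--lr_scheduler_min_lr_ratio` (per the help text, it is a ratio of the initial learning rate, so it sets the floor the schedule decays towards; the values here are hypothetical):

```python
# Illustrative arithmetic only; the scheduler itself enforces this floor.
initial_lr = 1e-4
min_lr_ratio = 0.1
print(initial_lr * min_lr_ratio)  # ~1e-05: the schedule decays towards this value, not to zero
```

`--lr_scheduler_timescale`, per its help text, falls back to the warmup length when left unset.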
@@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio

     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
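A worked example of the resolution above, with hypothetical numbers; floats are treated as ratios of the process-scaled total step count, ints pass through unchanged:

```python
# Hypothetical values, mirroring the arithmetic in get_scheduler_fix.
max_train_steps = 1000
num_processes = 2
lr_warmup_steps = 0.05  # float < 1, so a ratio
lr_decay_steps = 0.2    # float < 1, so a ratio

num_training_steps = max_train_steps * num_processes                         # 2000
num_warmup_steps = int(lr_warmup_steps * num_training_steps)                 # 100
num_decay_steps = int(lr_decay_steps * num_training_steps)                   # 400
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps   # 1500
```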
@@ -4332,13 +4369,13 @@ def wrap_check_needless_num_warmup_steps(return_vals):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    name = SchedulerType(name) or DiffusersSchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]

     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

-    if name == SchedulerType.PIECEWISE_CONSTANT:
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
         return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs

     # All other schedulers require `num_warmup_steps`
@@ -4348,6 +4385,9 @@ def wrap_check_needless_num_warmup_steps(return_vals):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
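For intuition, the general shape of an inverse square-root schedule looks roughly like the sketch below. This is an illustration of the schedule family under the stated assumptions (linear warmup, then decay proportional to 1/sqrt(step), with `timescale` defaulting to the warmup length), not a verbatim copy of the transformers implementation:

```python
import math

def inverse_sqrt_multiplier(step, num_warmup_steps, timescale=None):
    # Illustrative shape only: linear warmup, then ~1/sqrt(step) decay.
    if timescale is None:
        timescale = max(num_warmup_steps, 1)
    if step < num_warmup_steps:
        return step / max(num_warmup_steps, 1)
    shift = timescale - num_warmup_steps
    return 1.0 / math.sqrt((step + shift) / timescale)

# The multiplier reaches 1.0 at the end of warmup, then falls off slowly.
print(inverse_sqrt_multiplier(100, num_warmup_steps=100, timescale=100))  # 1.0
print(inverse_sqrt_multiplier(400, num_warmup_steps=100, timescale=100))  # 0.5
```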
@@ -4366,7 +4406,31 @@ def wrap_check_needless_num_warmup_steps(return_vals):
             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
         )

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)


 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
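Putting the pieces together, a warmup-stable-decay schedule holds the learning rate constant between a linear warmup and a final decay down to `min_lr_ratio` times the initial LR. The sketch below illustrates that shape only; the decay curve used by the actual library scheduler may differ (e.g. cosine rather than linear):

```python
def wsd_multiplier(step, num_warmup_steps, num_stable_steps, num_decay_steps, min_lr_ratio=0.0):
    # Illustrative warmup-stable-decay shape: linear warmup, flat plateau,
    # then linear decay to min_lr_ratio.
    if step < num_warmup_steps:
        return step / max(num_warmup_steps, 1)
    if step < num_warmup_steps + num_stable_steps:
        return 1.0
    progress = (step - num_warmup_steps - num_stable_steps) / max(num_decay_steps, 1)
    progress = min(progress, 1.0)
    return (1.0 - progress) * (1.0 - min_lr_ratio) + min_lr_ratio

# With 100 warmup, 1500 stable and 400 decay steps (matching the earlier example):
print(wsd_multiplier(50, 100, 1500, 400))         # 0.5  (mid-warmup)
print(wsd_multiplier(800, 100, 1500, 400))        # 1.0  (plateau)
print(wsd_multiplier(1800, 100, 1500, 400, 0.1))  # 0.55 (halfway through decay, floor 0.1)
```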