@@ -493,13 +493,15 @@ def train(self, args):
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of other models than network to save only network weights
-            if accelerator.is_main_process:
+            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+            if accelerator.is_main_process or args.deepspeed:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    weights.pop(i)
+                    if len(weights) > i:
+                        weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")

         # save current epoch and step
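
The length guard matters because under DeepSpeed the hook can run on every process with a `weights` list shorter than `models` (see the linked diffusers issue). A toy sketch, with made-up model and weight names that are not from the patch, of why the indices are popped in reverse and guarded:

    # Toy sketch (names are assumptions, not from the patch): popping in
    # reverse keeps the remaining indices valid, and the length guard
    # tolerates a weights list shorter than models, as under DeepSpeed.
    models = ["text_encoder", "unet", "network"]
    weights = ["te_sd", "unet_sd", "network_sd"]

    remove_indices = [i for i, m in enumerate(models) if m != "network"]  # [0, 1]
    for i in reversed(remove_indices):
        if len(weights) > i:
            weights.pop(i)

    print(weights)  # ['network_sd'] -- only the network weights are saved
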
@@ -813,11 +815,12 @@ def load_model_hook(models, input_dir):
             )
             logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
             initial_step *= args.gradient_accumulation_steps
+
+            # set epoch to start to make initial_step less than len(train_dataloader)
+            epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
         else:
             # if not, only epoch no is skipped for informative purpose
-            epoch_to_start = initial_step // math.ceil(
-                len(train_dataloader) / args.gradient_accumulation_steps
-            )
+            epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
             initial_step = 0  # do not skip

         global_step = 0
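
A worked example of the `epoch_to_start` arithmetic, under assumed numbers (100 batches per epoch, gradient accumulation of 4):

    # Assumed numbers, not from the patch: 100 batches per epoch and
    # gradient_accumulation_steps = 4 give ceil(100 / 4) = 25 optimizer
    # steps per epoch, so a run resumed at step 60 restarts from epoch 2
    # (0-indexed).
    import math

    batches_per_epoch = 100              # stand-in for len(train_dataloader)
    gradient_accumulation_steps = 4
    initial_step = 60                    # optimizer steps already trained

    steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    epoch_to_start = initial_step // steps_per_epoch
    print(steps_per_epoch, epoch_to_start)  # 25 2
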
@@ -878,9 +881,11 @@ def remove_model(old_ckpt_name):
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

         # training loop
-        for skip_epoch in range(epoch_to_start):  # skip epochs
-            logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
-            initial_step -= len(train_dataloader)
+        if initial_step > 0:  # only if skip_until_initial_step is specified
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= len(train_dataloader)
+            global_step = initial_step

         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
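
The new guard keeps the skip loop from running (and from overwriting `global_step`) when `initial_step` is 0, i.e. when skipping is not in effect. A sketch of the loop's bookkeeping, with assumed numbers that only illustrate the mechanics:

    # Assumed numbers, bookkeeping only: each skipped epoch subtracts one
    # full dataloader's worth of batches from initial_step; the remainder
    # is the in-epoch offset that also seeds global_step.
    initial_step = 240            # batches left to skip (assumption)
    len_train_dataloader = 100
    epoch_to_start = 2            # whole epochs contained in initial_step (assumption)

    for skip_epoch in range(epoch_to_start):
        print(f"skipping epoch {skip_epoch + 1} because initial_step (multiplied) is {initial_step}")
        initial_step -= len_train_dataloader

    global_step = initial_step
    print(global_step)  # 40 -> offset within the first non-skipped epoch
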
@@ -892,7 +897,7 @@ def remove_model(old_ckpt_name):

             skipped_dataloader = None
             if initial_step > 0:
-                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
                 initial_step = 1

             for step, batch in enumerate(skipped_dataloader or train_dataloader):
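
`Accelerator.skip_first_batches(dataloader, num_batches)` is accelerate's API for fast-forwarding a dataloader; skipping `initial_step - 1` batches and then setting `initial_step = 1` lets the loop treat the next yielded batch as the one the checkpoint stopped on. A minimal, self-contained sketch (the `range(10)` dataset and batch counts are assumptions):

    # Minimal sketch of resuming mid-epoch; skip_first_batches is the real
    # accelerate API, while the toy dataset and step counts are assumptions.
    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    accelerator = Accelerator()
    train_dataloader = accelerator.prepare(DataLoader(range(10), batch_size=2))

    initial_step = 3  # resume from the 3rd batch of the epoch (assumption)
    skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)

    for step, batch in enumerate(skipped_dataloader or train_dataloader):
        print(step, batch)  # first batch printed is the 3rd batch of the epoch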