@@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
            text_encoder2 = accelerator.prepare(text_encoder2)
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

+    # when caching the Text Encoder outputs, move the Text Encoders to CPU
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders to CPU for sampling images; the Text Encoders do not work on CPU with fp16
+        text_encoder1.to("cpu", dtype=torch.float32)
+        text_encoder2.to("cpu", dtype=torch.float32)
+        clean_memory_on_device(accelerator.device)
+    else:
+        # make sure the Text Encoders are on GPU
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+
+    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling with fp16
+    if args.full_fp16:
+        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
+        # -> But we think it is OK to patch the accelerator even if DeepSpeed is enabled.
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # resume training if specified
+    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
    if args.fused_backward_pass:
        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused
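The block moved into place above parks both Text Encoders on CPU in fp32 whenever their outputs have been pre-cached, then frees device memory; otherwise it keeps them on the training device. A minimal sketch of that pattern follows; `place_text_encoders` is a hypothetical helper, and the `gc`/`empty_cache` calls only approximate what the repository's `clean_memory_on_device` is assumed to do.

```python
# Minimal sketch (not the repository's implementation) of the offload pattern above.
import gc
import torch


def place_text_encoders(text_encoders, device: torch.device, cache_outputs: bool) -> None:
    if cache_outputs:
        for te in text_encoders:
            # fp16 weights do not run on CPU, so park the encoders there in fp32
            te.to("cpu", dtype=torch.float32)
        gc.collect()
        if device.type == "cuda":
            # reclaim the device memory the encoders were using
            torch.cuda.empty_cache()
    else:
        for te in text_encoders:
            # outputs are computed on the fly, so the encoders must stay on the training device
            te.to(device)
```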
@@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor):
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1

-    # when caching the Text Encoder outputs, move the Text Encoders to CPU
-    if args.cache_text_encoder_outputs:
-        # move Text Encoders to CPU for sampling images; the Text Encoders do not work on CPU with fp16
-        text_encoder1.to("cpu", dtype=torch.float32)
-        text_encoder2.to("cpu", dtype=torch.float32)
-        clean_memory_on_device(accelerator.device)
-    else:
-        # make sure the Text Encoders are on GPU
-        text_encoder1.to(accelerator.device)
-        text_encoder2.to(accelerator.device)
-
-    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling with fp16
-    if args.full_fp16:
-        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
-        # -> But we think it is OK to patch the accelerator even if DeepSpeed is enabled.
-        train_util.patch_accelerator_for_fp16_training(accelerator)
-
-    # resume training if specified
-    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
-
    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
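A worked example of the epoch arithmetic in the context lines above, using illustrative numbers that are not taken from the script:

```python
# 1000 batches per epoch, gradient accumulation of 4, and a 5000-step budget.
import math

len_train_dataloader = 1000
gradient_accumulation_steps = 4
max_train_steps = 5000

num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
print(num_update_steps_per_epoch, num_train_epochs)  # 250 20
```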
@@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor):
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
+        )

    # For --sample_at_first
    sdxl_train_util.sample_images(
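A minimal sketch of how the reformatted `init_trackers` call is typically driven, assuming Hugging Face accelerate with wandb installed and the `toml` package; the run name, config dict, and `tracker_config.toml` path are illustrative stand-ins, not values from the script.

```python
import os

import toml
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")

# plays the role of args.wandb_run_name
init_kwargs = {"wandb": {"name": "sdxl-finetune-demo"}}
if os.path.exists("tracker_config.toml"):
    # a TOML file, when given, replaces the per-tracker kwargs entirely (as in the script)
    init_kwargs = toml.load("tracker_config.toml")

accelerator.init_trackers(
    "finetuning",                    # default tracker name used above
    config={"learning_rate": 1e-5},  # stand-in for train_util.get_sanitized_config_or_none(args)
    init_kwargs=init_kwargs,
)
```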