import torch
import re
from library.utils import setup_logging
-
setup_logging()
import logging
-
logger = logging.getLogger(__name__)

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
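
For reference, `RE_UPDOWN` pulls the block kind, block index, submodule kind, and submodule index out of a U-Net LoRA module name. A quick check, using a hypothetical module name for illustration:

m = RE_UPDOWN.search("lora_unet_up_blocks_1_attentions_0_proj_in")  # illustrative name
print(m.groups())  # ('up', '1', 'attentions', '0')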
@@ -506,15 +504,6 @@ def create_network(
    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

-    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
-    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
-    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
-    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
-    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
-    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
-        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
-
    return network

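The lines removed above undo the LoRA+ wiring: the three ratios were read from the network kwargs and forwarded to the network, where they scaled the learning rate of the `lora_up` (matrix B) parameters relative to the rest. A minimal sketch of that grouping effect, with illustrative names (this is not the removed implementation itself):

def split_loraplus_groups(named_params, lr, ratio):
    # Hedged sketch: "lora_up" (B) weights train at lr * ratio, everything else at the base lr.
    named_params = list(named_params)  # allow two passes over the iterable
    plus = [p for n, p in named_params if "lora_up" in n]
    base = [p for n, p in named_params if "lora_up" not in n]
    groups = []
    if base:
        groups.append({"params": base, "lr": lr})
    if plus:
        groups.append({"params": plus, "lr": lr * ratio})
    return groups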
@@ -540,9 +529,7 @@ def parse_floats(s):
            len(block_dims) == num_total_blocks
        ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
    else:
-        logger.warning(
-            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
-        )
+        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
        block_dims = [network_dim] * num_total_blocks

    if block_alphas is not None:
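
For context, the fallback above simply repeats the global `network_dim` once per block, so the length assertion always holds. A tiny sketch (the block count here is an assumed value for illustration; the real `num_total_blocks` is computed elsewhere in this file):

network_dim = 8
num_total_blocks = 25  # assumed for illustration
block_dims = [network_dim] * num_total_blocks
assert len(block_dims) == num_total_blocks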
@@ -816,31 +803,21 @@ def __init__(
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

-        self.loraplus_lr_ratio = None
-        self.loraplus_unet_lr_ratio = None
-        self.loraplus_text_encoder_lr_ratio = None
-
        if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
        elif block_dims is not None:
            logger.info(f"create LoRA network from block_dims")
-            logger.info(
-                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-            )
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            logger.info(f"block_dims: {block_dims}")
            logger.info(f"block_alphas: {block_alphas}")
            if conv_block_dims is not None:
                logger.info(f"conv_block_dims: {conv_block_dims}")
                logger.info(f"conv_block_alphas: {conv_block_alphas}")
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(
-                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-            )
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            if self.conv_lora_dim is not None:
-                logger.info(
-                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
-                )
+                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        # create module instances
        def create_modules(
@@ -962,11 +939,6 @@ def create_modules(
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

-    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
-        self.loraplus_lr_ratio = loraplus_lr_ratio
-        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
-        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
-
    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
@@ -1065,42 +1037,18 @@ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

-        def assemble_params(loras, lr, ratio):
-            param_groups = {"lora": {}, "plus": {}}
-            for lora in loras:
-                for name, param in lora.named_parameters():
-                    if ratio is not None and "lora_up" in name:
-                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
-                    else:
-                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
-
+        def enumerate_params(loras: List[LoRAModule]):
            params = []
-            for key in param_groups.keys():
-                param_data = {"params": param_groups[key].values()}
-
-                if len(param_data["params"]) == 0:
-                    continue
-
-                if lr is not None:
-                    if key == "plus":
-                        param_data["lr"] = lr * ratio
-                    else:
-                        param_data["lr"] = lr
-
-                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
-                    continue
-
-                params.append(param_data)
-
+            for lora in loras:
+                # params.extend(lora.parameters())
+                params.extend(lora.get_trainable_params())
            return params

        if self.text_encoder_loras:
-            params = assemble_params(
-                self.text_encoder_loras,
-                text_encoder_lr if text_encoder_lr is not None else default_lr,
-                self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
-            )
-            all_params.extend(params)
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)

        if self.unet_loras:
            if self.block_lr:
@@ -1114,20 +1062,21 @@ def assemble_params(loras, lr, ratio):

                # set the parameters for each block
                for idx, block_loras in block_idx_to_lora.items():
-                    params = assemble_params(
-                        block_loras,
-                        (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        self.loraplus_unet_lr_ratio or self.loraplus_ratio,
-                    )
-                    all_params.extend(params)
+                    param_data = {"params": enumerate_params(block_loras)}
+
+                    if unet_lr is not None:
+                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                    elif default_lr is not None:
+                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
+                        continue
+                    all_params.append(param_data)

            else:
-                params = assemble_params(
-                    self.unet_loras,
-                    unet_lr if unet_lr is not None else default_lr,
-                    self.loraplus_unet_lr_ratio or self.loraplus_ratio,
-                )
-                all_params.extend(params)
+                param_data = {"params": enumerate_params(self.unet_loras)}
+                if unet_lr is not None:
+                    param_data["lr"] = unet_lr
+                all_params.append(param_data)

        return all_params

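With the revert in place, `prepare_optimizer_params` once again returns a plain list of torch parameter-group dicts, which can be handed straight to any torch optimizer; a group-level "lr" entry overrides the optimizer's default. A minimal usage sketch, assuming `network` is the object returned by `create_network`:

import torch

all_params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
optimizer = torch.optim.AdamW(all_params, lr=1e-4)  # groups carrying their own "lr" take precedence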
@@ -1144,9 +1093,6 @@ def on_epoch_start(self, text_encoder, unet):
    def get_trainable_params(self):
        return self.parameters()

-    def get_trainable_named_params(self):
-        return self.named_parameters()
-
    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None