Skip to content

Commit 3c8193f

Browse files
committed
revert lora+ for lora_fa
1 parent c6a4370 commit 3c8193f

File tree

1 file changed

+25
-79
lines changed

1 file changed

+25
-79
lines changed

networks/lora_fa.py

Lines changed: 25 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
import torch
1616
import re
1717
from library.utils import setup_logging
18-
1918
setup_logging()
2019
import logging
21-
2220
logger = logging.getLogger(__name__)
2321

2422
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -506,15 +504,6 @@ def create_network(
506504
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
507505
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
508506

509-
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
510-
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
511-
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
512-
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
513-
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
514-
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
515-
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
516-
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
517-
518507
return network
519508

520509

@@ -540,9 +529,7 @@ def parse_floats(s):
540529
len(block_dims) == num_total_blocks
541530
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
542531
else:
543-
logger.warning(
544-
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
545-
)
532+
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
546533
block_dims = [network_dim] * num_total_blocks
547534

548535
if block_alphas is not None:
@@ -816,31 +803,21 @@ def __init__(
816803
self.rank_dropout = rank_dropout
817804
self.module_dropout = module_dropout
818805

819-
self.loraplus_lr_ratio = None
820-
self.loraplus_unet_lr_ratio = None
821-
self.loraplus_text_encoder_lr_ratio = None
822-
823806
if modules_dim is not None:
824807
logger.info(f"create LoRA network from weights")
825808
elif block_dims is not None:
826809
logger.info(f"create LoRA network from block_dims")
827-
logger.info(
828-
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
829-
)
810+
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
830811
logger.info(f"block_dims: {block_dims}")
831812
logger.info(f"block_alphas: {block_alphas}")
832813
if conv_block_dims is not None:
833814
logger.info(f"conv_block_dims: {conv_block_dims}")
834815
logger.info(f"conv_block_alphas: {conv_block_alphas}")
835816
else:
836817
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
837-
logger.info(
838-
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
839-
)
818+
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
840819
if self.conv_lora_dim is not None:
841-
logger.info(
842-
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
843-
)
820+
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
844821

845822
# create module instances
846823
def create_modules(
@@ -962,11 +939,6 @@ def create_modules(
962939
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
963940
names.add(lora.lora_name)
964941

965-
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
966-
self.loraplus_lr_ratio = loraplus_lr_ratio
967-
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
968-
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
969-
970942
def set_multiplier(self, multiplier):
971943
self.multiplier = multiplier
972944
for lora in self.text_encoder_loras + self.unet_loras:
@@ -1065,42 +1037,18 @@ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
10651037
self.requires_grad_(True)
10661038
all_params = []
10671039

1068-
def assemble_params(loras, lr, ratio):
1069-
param_groups = {"lora": {}, "plus": {}}
1070-
for lora in loras:
1071-
for name, param in lora.named_parameters():
1072-
if ratio is not None and "lora_up" in name:
1073-
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
1074-
else:
1075-
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
1076-
1040+
def enumerate_params(loras: List[LoRAModule]):
10771041
params = []
1078-
for key in param_groups.keys():
1079-
param_data = {"params": param_groups[key].values()}
1080-
1081-
if len(param_data["params"]) == 0:
1082-
continue
1083-
1084-
if lr is not None:
1085-
if key == "plus":
1086-
param_data["lr"] = lr * ratio
1087-
else:
1088-
param_data["lr"] = lr
1089-
1090-
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
1091-
continue
1092-
1093-
params.append(param_data)
1094-
1042+
for lora in loras:
1043+
# params.extend(lora.parameters())
1044+
params.extend(lora.get_trainable_params())
10951045
return params
10961046

10971047
if self.text_encoder_loras:
1098-
params = assemble_params(
1099-
self.text_encoder_loras,
1100-
text_encoder_lr if text_encoder_lr is not None else default_lr,
1101-
self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
1102-
)
1103-
all_params.extend(params)
1048+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
1049+
if text_encoder_lr is not None:
1050+
param_data["lr"] = text_encoder_lr
1051+
all_params.append(param_data)
11041052

11051053
if self.unet_loras:
11061054
if self.block_lr:
@@ -1114,20 +1062,21 @@ def assemble_params(loras, lr, ratio):
11141062

11151063
# blockごとにパラメータを設定する
11161064
for idx, block_loras in block_idx_to_lora.items():
1117-
params = assemble_params(
1118-
block_loras,
1119-
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
1120-
self.loraplus_unet_lr_ratio or self.loraplus_ratio,
1121-
)
1122-
all_params.extend(params)
1065+
param_data = {"params": enumerate_params(block_loras)}
1066+
1067+
if unet_lr is not None:
1068+
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
1069+
elif default_lr is not None:
1070+
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
1071+
if ("lr" in param_data) and (param_data["lr"] == 0):
1072+
continue
1073+
all_params.append(param_data)
11231074

11241075
else:
1125-
params = assemble_params(
1126-
self.unet_loras,
1127-
unet_lr if unet_lr is not None else default_lr,
1128-
self.loraplus_unet_lr_ratio or self.loraplus_ratio,
1129-
)
1130-
all_params.extend(params)
1076+
param_data = {"params": enumerate_params(self.unet_loras)}
1077+
if unet_lr is not None:
1078+
param_data["lr"] = unet_lr
1079+
all_params.append(param_data)
11311080

11321081
return all_params
11331082

@@ -1144,9 +1093,6 @@ def on_epoch_start(self, text_encoder, unet):
11441093
def get_trainable_params(self):
11451094
return self.parameters()
11461095

1147-
def get_trainable_named_params(self):
1148-
return self.named_parameters()
1149-
11501096
def save_weights(self, file, dtype, metadata):
11511097
if metadata is not None and len(metadata) == 0:
11521098
metadata = None

0 commit comments

Comments
 (0)