File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
272272 # 学習を準備する:モデルを適切な状態にする
273273 if args .gradient_checkpointing :
274274 unet .enable_gradient_checkpointing ()
275- train_unet = args .learning_rate > 0
275+ train_unet = args .learning_rate != 0
276276 train_text_encoder1 = False
277277 train_text_encoder2 = False
278278
@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
284284 text_encoder2 .gradient_checkpointing_enable ()
285285 lr_te1 = args .learning_rate_te1 if args .learning_rate_te1 is not None else args .learning_rate # 0 means not train
286286 lr_te2 = args .learning_rate_te2 if args .learning_rate_te2 is not None else args .learning_rate # 0 means not train
287- train_text_encoder1 = lr_te1 > 0
288- train_text_encoder2 = lr_te2 > 0
287+ train_text_encoder1 = lr_te1 != 0
288+ train_text_encoder2 = lr_te2 != 0
289289
290290 # caching one text encoder output is not supported
291291 if not train_text_encoder1 :
You can’t perform that action at this time.
0 commit comments