1111
1212import torch
1313from library .device_utils import init_ipex , get_preferred_device
14+
1415init_ipex ()
1516
1617from torchvision import transforms
1718
1819import library .model_util as model_util
1920import library .train_util as train_util
2021from library .utils import setup_logging
22+
2123setup_logging ()
2224import logging
25+
2326logger = logging .getLogger (__name__ )
2427
2528DEVICE = get_preferred_device ()
@@ -89,7 +92,9 @@ def main(args):
8992
9093 # bucketのサイズを計算する
9194 max_reso = tuple ([int (t ) for t in args .max_resolution .split ("," )])
92- assert len (max_reso ) == 2 , f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: { args .max_resolution } "
95+ assert (
96+ len (max_reso ) == 2
97+ ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: { args .max_resolution } "
9398
9499 bucket_manager = train_util .BucketManager (
95100 args .bucket_no_upscale , max_reso , args .min_bucket_reso , args .max_bucket_reso , args .bucket_reso_steps
@@ -107,7 +112,7 @@ def main(args):
107112 def process_batch (is_last ):
108113 for bucket in bucket_manager .buckets :
109114 if (is_last and len (bucket ) > 0 ) or len (bucket ) >= args .batch_size :
110- train_util .cache_batch_latents (vae , True , bucket , args .flip_aug , False )
115+ train_util .cache_batch_latents (vae , True , bucket , args .flip_aug , args . alpha_mask , False )
111116 bucket .clear ()
112117
113118 # 読み込みの高速化のためにDataLoaderを使うオプション
@@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
208213 parser .add_argument ("in_json" , type = str , help = "metadata file to input / 読み込むメタデータファイル" )
209214 parser .add_argument ("out_json" , type = str , help = "metadata file to output / メタデータファイル書き出し先" )
210215 parser .add_argument ("model_name_or_path" , type = str , help = "model name or path to encode latents / latentを取得するためのモデル" )
211- parser .add_argument ("--v2" , action = "store_true" , help = "not used (for backward compatibility) / 使用されません(互換性のため残してあります)" )
216+ parser .add_argument (
217+ "--v2" , action = "store_true" , help = "not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
218+ )
212219 parser .add_argument ("--batch_size" , type = int , default = 1 , help = "batch size in inference / 推論時のバッチサイズ" )
213220 parser .add_argument (
214221 "--max_data_loader_n_workers" ,
@@ -231,18 +238,32 @@ def setup_parser() -> argparse.ArgumentParser:
231238 help = "steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します" ,
232239 )
233240 parser .add_argument (
234- "--bucket_no_upscale" , action = "store_true" , help = "make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
241+ "--bucket_no_upscale" ,
242+ action = "store_true" ,
243+ help = "make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" ,
235244 )
236245 parser .add_argument (
237- "--mixed_precision" , type = str , default = "no" , choices = ["no" , "fp16" , "bf16" ], help = "use mixed precision / 混合精度を使う場合、その精度"
246+ "--mixed_precision" ,
247+ type = str ,
248+ default = "no" ,
249+ choices = ["no" , "fp16" , "bf16" ],
250+ help = "use mixed precision / 混合精度を使う場合、その精度" ,
238251 )
239252 parser .add_argument (
240253 "--full_path" ,
241254 action = "store_true" ,
242255 help = "use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)" ,
243256 )
244257 parser .add_argument (
245- "--flip_aug" , action = "store_true" , help = "flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
258+ "--flip_aug" ,
259+ action = "store_true" ,
260+ help = "flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" ,
261+ )
262+ parser .add_argument (
263+ "--alpha_mask" ,
264+ type = str ,
265+ default = "" ,
266+ help = "save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する" ,
246267 )
247268 parser .add_argument (
248269 "--skip_existing" ,
0 commit comments