55import random
66import time
77from multiprocessing import Value
8- from types import SimpleNamespace
8+ from omegaconf import OmegaConf
99import toml
1010
1111from tqdm import tqdm
@@ -148,8 +148,10 @@ def train(args):
148148 "in_channels" : 4 ,
149149 "layers_per_block" : 2 ,
150150 "mid_block_scale_factor" : 1 ,
151+ "mid_block_type" : "UNetMidBlock2DCrossAttn" ,
151152 "norm_eps" : 1e-05 ,
152153 "norm_num_groups" : 32 ,
154+ "num_attention_heads" : [5 , 10 , 20 , 20 ],
153155 "num_class_embeds" : None ,
154156 "only_cross_attention" : False ,
155157 "out_channels" : 4 ,
@@ -179,8 +181,10 @@ def train(args):
179181 "in_channels" : 4 ,
180182 "layers_per_block" : 2 ,
181183 "mid_block_scale_factor" : 1 ,
184+ "mid_block_type" : "UNetMidBlock2DCrossAttn" ,
182185 "norm_eps" : 1e-05 ,
183186 "norm_num_groups" : 32 ,
187+ "num_attention_heads" : 8 ,
184188 "out_channels" : 4 ,
185189 "sample_size" : 64 ,
186190 "up_block_types" : ["UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ],
@@ -193,7 +197,7 @@ def train(args):
193197 "resnet_time_scale_shift" : "default" ,
194198 "projection_class_embeddings_input_dim" : None ,
195199 }
196- unet .config = SimpleNamespace ( ** unet .config )
200+ unet .config = OmegaConf . create ( unet .config )
197201
198202 controlnet = ControlNetModel .from_unet (unet )
199203
0 commit comments