 from tqdm import tqdm


-def mock_wrapper(f):
-    import sys
+class MockCtx(object):
+    def __enter__(self):
+        flow.mock_torch.enable(lazy=True)

-    flow.mock_torch.enable(lazy=True)
-    ret = f()
-    flow.mock_torch.disable()
-    # TODO: this trick of py mod purging will be removed
-    tmp = sys.modules.copy()
-    for x in tmp:
-        if x.startswith("diffusers"):
-            del sys.modules[x]
-    return ret
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        flow.mock_torch.disable()


-class UNetGraph(flow.nn.Graph):
+def get_unet(token):
+    from diffusers import UNet2DConditionModel
+
+    unet = UNet2DConditionModel.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        use_auth_token=token,
+        revision="fp16",
+        torch_dtype=flow.float16,
+        subfolder="unet",
+    )
+    with flow.no_grad():
+        unet = unet.to("cuda")
+    return unet
+
+
+class UNetGraphWithCache(flow.nn.Graph):
+    @flow.nn.Graph.with_dynamic_input_shape(size=9)
     def __init__(self, unet):
-        super().__init__()
+        super().__init__(enable_get_runtime_state_dict=True)
         self.unet = unet
         self.config.enable_cudnn_conv_heuristic_search_algo(False)
         self.config.allow_fuse_add_to_output(True)
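For readers unfamiliar with OneFlow's torch mocking: inside a `MockCtx` block, `flow.mock_torch.enable(lazy=True)` makes subsequent `import torch` statements resolve to OneFlow, so the stock diffusers `UNet2DConditionModel` above is built directly on OneFlow tensors, and `__exit__` restores the real torch. A minimal usage sketch (not part of the diff, assuming `flow.mock_torch` behaves as it is used here):

```python
import oneflow as flow


class MockCtx(object):
    def __enter__(self):
        flow.mock_torch.enable(lazy=True)

    def __exit__(self, exc_type, exc_val, exc_tb):
        flow.mock_torch.disable()


with MockCtx():
    # Inside the context, "torch" is mocked by OneFlow, so torch-style
    # model code (e.g. diffusers) produces OneFlow tensors and modules.
    import torch

    x = torch.ones(2, 2)  # under the mock this is a flow.Tensor

print(type(x))
```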
@@ -51,58 +61,108 @@ def build(self, latent_model_input, t, text_embeddings):
             latent_model_input, t, encoder_hidden_states=text_embeddings
         ).sample

+    def warmup_with_arg(self, arg_meta_of_sizes):
+        for arg_metas in arg_meta_of_sizes:
+            print(f"warmup {arg_metas=}")
+            arg_tensors = [flow.empty(a[0], dtype=a[1]).to("cuda") for a in arg_metas]
+            self(*arg_tensors)  # build and warmup

-def get_graph(token):
-    from diffusers import UNet2DConditionModel
+    def warmup_with_load(self, file_path):
+        state_dict = flow.load(file_path)
+        self.load_runtime_state_dict(state_dict)

-    with flow.no_grad():
-        unet = UNet2DConditionModel.from_pretrained(
-            "runwayml/stable-diffusion-v1-5",
-            use_auth_token=token,
-            revision="fp16",
-            torch_dtype=flow.float16,
-            subfolder="unet",
-        )
-        unet = unet.to("cuda")
-        return UNetGraph(unet)
+    def save_graph(self, file_path):
+        state_dict = self.runtime_state_dict()
+        flow.save(state_dict, file_path)
+
+
+def image_dim(i):
+    return 768 + 128 * i
+
+
+def noise_shape(batch_size, num_channels, image_w, image_h):
+    sizes = (image_w // 8, image_h // 8)
+    return (batch_size, num_channels) + sizes
+
+
+def get_arg_meta_of_sizes(batch_sizes, resolution_scales, num_channels):
+    return [
+        [
+            (
+                noise_shape(batch_size, num_channels, image_dim(i), image_dim(j)),
+                flow.float16,
+            ),
+            ((1,), flow.int64),
+            ((batch_size, 77, 768), flow.float16),
+        ]
+        for batch_size in batch_sizes
+        for i in resolution_scales
+        for j in resolution_scales
+    ]


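To make the warmup grid concrete (this note and the snippet below are not part of the diff): with the defaults used in `benchmark` below, `BATCH_SIZES = [2]` and `RESOLUTION_SCALES = [2, 1, 0]`, `image_dim` maps the scales to 1024, 896 and 768 pixels, and `noise_shape` divides by 8, so warmup covers nine latent shapes from `(2, 4, 128, 128)` down to `(2, 4, 96, 96)`, which presumably is why the graph is declared with `with_dynamic_input_shape(size=9)`.

```python
# Standalone sanity check, assuming the same helpers as in the diff above.
def image_dim(i):
    return 768 + 128 * i


def noise_shape(batch_size, num_channels, image_w, image_h):
    return (batch_size, num_channels) + (image_w // 8, image_h // 8)


shapes = [
    noise_shape(2, 4, image_dim(i), image_dim(j))
    for i in [2, 1, 0]
    for j in [2, 1, 0]
]
print(len(shapes))            # 9 -> matches with_dynamic_input_shape(size=9)
print(shapes[0], shapes[-1])  # (2, 4, 128, 128) (2, 4, 96, 96)
```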
 @click.command()
 @click.option("--token")
-@click.option("--repeat", default=1000)
+@click.option("--repeat", default=100)
 @click.option("--sync_interval", default=50)
-def benchmark(token, repeat, sync_interval):
+@click.option("--save", is_flag=True)
+@click.option("--load", is_flag=True)
+@click.option("--file", type=str, default="./unet_graphs")
+def benchmark(token, repeat, sync_interval, save, load, file):
+    RESOLUTION_SCALES = [2, 1, 0]
+    BATCH_SIZES = [2]
+    # TODO: reproduce bug caused by changing batch
+    # BATCH_SIZES = [4, 2]
+
     # create a mocked unet graph
-    unet_graph = mock_wrapper(lambda: get_graph(token))
+    num_channels = 4
+
+    warmup_meta_of_sizes = get_arg_meta_of_sizes(BATCH_SIZES, RESOLUTION_SCALES, num_channels)
+    for (i, m) in enumerate(warmup_meta_of_sizes):
+        print(f"warmup case #{i + 1}:", m)
+    with MockCtx():
+        unet = get_unet(token)
+        unet_graph = UNetGraphWithCache(unet)
+        if load == True:
+            print("loading graphs...")
+            unet_graph.warmup_with_load(file)
+        else:
+            print("warmup with arguments...")
+            unet_graph.warmup_with_arg(warmup_meta_of_sizes)

     # generate inputs with torch
     from diffusers.utils import floats_tensor
     import torch

-    batch_size = 2
-    num_channels = 4
-    sizes = (64, 64)
-    noise = (
-        floats_tensor((batch_size, num_channels) + sizes).to("cuda").to(torch.float16)
-    )
-    print(f"{type(noise)=}")
     time_step = torch.tensor([10]).to("cuda")
-    encoder_hidden_states = (
-        floats_tensor((batch_size, 77, 768)).to("cuda").to(torch.float16)
-    )
-
-    # convert to oneflow tensors
-    [noise, time_step, encoder_hidden_states] = [
-        flow.utils.tensor.from_torch(x)
-        for x in [noise, time_step, encoder_hidden_states]
+    encoder_hidden_states_of_sizes = {
+        batch_size: floats_tensor((batch_size, 77, 768)).to("cuda").to(torch.float16)
+        for batch_size in BATCH_SIZES
+    }
+    noise_of_sizes = [
+        floats_tensor(noise_shape(batch_size, num_channels, image_dim(i), image_dim(j)))
+        .to("cuda")
+        .to(torch.float16)
+        for batch_size in BATCH_SIZES
+        for i in RESOLUTION_SCALES
+        for j in RESOLUTION_SCALES
     ]
-    unet_graph(noise, time_step, encoder_hidden_states)
+    noise_of_sizes = [flow.utils.tensor.from_torch(x) for x in noise_of_sizes]
+    encoder_hidden_states_of_sizes = {
+        k: flow.utils.tensor.from_torch(v) for k, v in encoder_hidden_states_of_sizes.items()
+    }
+    # convert to oneflow tensors
+    time_step = flow.utils.tensor.from_torch(time_step)

     flow._oneflow_internal.eager.Sync()
     import time

     t0 = time.time()
     for r in tqdm(range(repeat)):
+        import random
+
+        noise = random.choice(noise_of_sizes)
+        encoder_hidden_states = encoder_hidden_states_of_sizes[noise.shape[0]]
         out = unet_graph(noise, time_step, encoder_hidden_states)
         # convert to torch tensors
         out = flow.utils.tensor.to_torch(out)
@@ -116,6 +176,10 @@ def benchmark(token, repeat, sync_interval):
         f"Finish {repeat} steps in {duration:.3f} seconds, average {throughput:.2f} it/s"
     )

+    if save:
+        print("saving graphs...")
+        unet_graph.save_graph(file)
+

 if __name__ == "__main__":
     print(f"{flow.__path__=}")
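Presumably the intended workflow with the new flags is: run the benchmark once with `--save` so the graphs compiled for all nine warmup shapes are persisted via `runtime_state_dict()` to the `--file` path (default `./unet_graphs`), then start subsequent runs with `--load` so `warmup_with_load` restores the cached graphs instead of recompiling, e.g. `python <this script> --token <HF_TOKEN> --save` followed by `python <this script> --token <HF_TOKEN> --load`.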