# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np

# oneflow exposes a PyTorch-compatible API, so it is imported under the familiar torch name
import oneflow as torch

from onediff import OneFlowStableDiffusionImg2ImgPipeline

from diffusers import (
    AutoencoderKL,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, torch_device
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

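# Tests for OneFlowStableDiffusionImg2ImgPipeline: the fast tests build the pipeline
# from tiny dummy components, while the *_pipeline tests load the full
# CompVis/stable-diffusion-v1-4 checkpoint.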
class PipelineFastTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0))
        return torch.tensor(image)

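    # The dummy_* properties below build deliberately tiny sub-models so the fast tests
    # stay small and quick.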
    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

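    # No-op safety checker: returns the images unchanged and never flags NSFW content.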
    @property
    def dummy_safety_checker(self):
        def check(images, *args, **kwargs):
            return images, [False] * len(images)

        return check

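    # Minimal stand-in for the CLIP feature extractor: returns an object whose
    # pixel_values is an empty tensor (numpy when return_tensors="np") and which
    # supports .to(device).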
    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            if "return_tensors" in kwargs:
                return_tensors = kwargs["return_tensors"]
            else:
                return_tensors = "pt"

            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])
                    if return_tensors == "np":
                        self.pixel_values = torch.ones([0]).numpy()

                def to(self, device):
                    if return_tensors == "np":
                        return self
                    self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_stable_diffusion_img2img(self):
        unet = self.dummy_cond_unet.to(torch_device)
        scheduler = PNDMScheduler(skip_prk_steps=True, steps_offset=1)
        vae = self.dummy_vae.to(torch_device)
        bert = self.dummy_text_encoder.to(torch_device)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        init_image = self.dummy_image.to(torch_device)

        # make sure here that pndm scheduler skips prk
        sd_pipe = OneFlowStableDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            image=init_image,
            compile_unet=False,
        )
        image = output.images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            image=init_image,
            return_dict=False,
            compile_unet=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)

        # Do not modify the seed, or this test will not pass
        expected_slice = np.array([0.4287, 0.5450, 0.5239, 0.5432, 0.6519, 0.5665, 0.6027, 0.5805, 0.5145])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_img2img_k_lms(self):
        unet = self.dummy_cond_unet
        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        init_image = self.dummy_image.to(torch_device)

        sd_pipe = OneFlowStableDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=init_image,
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            compile_unet=False,
        )
        image = output.images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=init_image,
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            return_dict=False,
            compile_unet=False,
        )
        image_from_tuple = output[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)

        # Do not modify the seed, or this test will not pass
        expected_slice = np.array([0.4213, 0.5489, 0.5102, 0.5320, 0.6574, 0.5861, 0.6171, 0.5866, 0.5160])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

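    # The two tests below load the full CompVis/stable-diffusion-v1-4 checkpoint,
    # so they download real weights and pass use_auth_token=True to from_pretrained.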
    def test_stable_diffusion_img2img_pipeline(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/fantasy_landscape.png"
        )
        init_image = init_image.resize((768, 512))
        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            safety_checker=self.dummy_safety_checker,
            use_auth_token=True,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
            compile_unet=False,
        )
        image = output.images[0]

        assert image.shape == (512, 768, 3)
        # img2img is flaky across GPUs even in fp32, so using MAE here
        assert np.abs(expected_image - image).mean() < 1e-2

    def test_stable_diffusion_img2img_pipeline_k_lms(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/fantasy_landscape_k_lms.png"
        )
        init_image = init_image.resize((768, 512))
        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            scheduler=lms,
            safety_checker=self.dummy_safety_checker,
            use_auth_token=True,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
            compile_unet=False,
        )
        image = output.images[0]

        assert image.shape == (512, 768, 3)
        # img2img is flaky across GPUs even in fp32, so using MAE here
        assert np.abs(expected_image - image).mean() < 1e-2


if __name__ == "__main__":
    unittest.main()