Commit 49e3946

Add img2img test and model migration (#148)
1. Add a non-graph-load test case for image-to-image (the detailed test version is under the tests folder).
2. Migrate some models (OneFlow versions) from the fork: vae, unet_2d_cond, attention, embeddings, resnet.
3. Migrate some schedulers (OneFlow versions) from the fork: lms_scheduler, pndm_scheduler (see the sketch below).
4. Migrate the corresponding utils modules to support items 1–3.
1 parent 73519de commit 49e3946
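
The migrated schedulers from item 3 are exercised by the k_lms tests in the diff below, which swap the pipeline's default scheduler at load time. A minimal sketch of that pattern, mirroring the arguments used in the tests (the model ID and CUDA device are illustrative, not requirements of the API):

import oneflow as flow

flow.mock_torch.enable()  # enabled before importing onediff, as in the example below

from diffusers import LMSDiscreteScheduler
from onediff import OneFlowStableDiffusionImg2ImgPipeline

# Build the LMS scheduler and pass it in place of the pipeline's default scheduler.
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=lms,
    use_auth_token=True,
)
pipe = pipe.to("cuda")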

File tree

4 files changed: +358 −0 lines


examples/image_to_image.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import oneflow as flow
from PIL import Image

flow.mock_torch.enable()
from onediff import OneFlowStableDiffusionImg2ImgPipeline

pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=flow.float16,
)

pipe = pipe.to("cuda")

prompt = "sea,beach,the waves crashed on the sand,blue sky with white cloud"

img = Image.new("RGB", (512, 512), "#1f80f0")

with flow.autocast("cuda"):
    images = pipe(
        prompt,
        image=img,
        guidance_scale=10,
        num_inference_steps=100,
        compile_unet=False,
        output_type="np",
    ).images

for i, image in enumerate(images):
    pipe.numpy_to_pil(image)[0].save(f"{prompt}-of-{i}.png")
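
The example relies on flow.mock_torch.enable(): once enabled, later `import torch` statements (including those inside onediff and diffusers) resolve to oneflow, so torch-style code dispatches to OneFlow kernels. A minimal sketch of the effect, assuming only the behavior used in the example above:

import oneflow as flow

flow.mock_torch.enable()

import torch  # resolves to oneflow after enable()

x = torch.randn(2, 3)  # tensor is created by OneFlow
print(type(x))         # an oneflow tensor type
print(x.sum())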

src/onediff/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,3 +31,4 @@ def dummy_randn(*args, **kwargs):
 from .pipeline_stable_diffusion_inpaint_oneflow import (
     OneFlowStableDiffusionInpaintPipeline,
 )
+
Lines changed: 328 additions & 0 deletions
@@ -0,0 +1,328 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import oneflow as torch

from onediff import OneFlowStableDiffusionImg2ImgPipeline

from diffusers import (
    AutoencoderKL,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, torch_device
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer


class PipelineFastTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0))
        return torch.tensor(image)

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_safety_checker(self):
        def check(images, *args, **kwargs):
            return images, [False] * len(images)

        return check

    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            if "return_tensors" in kwargs:
                return_tensors = kwargs["return_tensors"]
            else:
                return_tensors = "pt"

            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])
                    if return_tensors == "np":
                        self.pixel_values = torch.ones([0]).numpy()

                def to(self, device):
                    if return_tensors == "np":
                        return self
                    self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_stable_diffusion_img2img(self):
        unet = self.dummy_cond_unet.to(torch_device)
        scheduler = PNDMScheduler(skip_prk_steps=True, steps_offset=1)
        vae = self.dummy_vae.to(torch_device)
        bert = self.dummy_text_encoder.to(torch_device)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        init_image = self.dummy_image.to(torch_device)

        # make sure here that pndm scheduler skips prk
        sd_pipe = OneFlowStableDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        # prompt = "sea,beach,the waves crashed on the sand,blue sky with white cloud"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            image=init_image,
            compile_unet=False,
        )
        image = output.images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            image=init_image,
            return_dict=False,
            compile_unet=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)

        # Do not modify any seed number, or this test will not pass
        expected_slice = np.array([0.4287, 0.5450, 0.5239, 0.5432, 0.6519, 0.5665, 0.6027, 0.5805, 0.5145])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_img2img_k_lms(self):
        unet = self.dummy_cond_unet
        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        init_image = self.dummy_image.to(torch_device)

        # make sure here that pndm scheduler skips prk
        sd_pipe = OneFlowStableDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=init_image,
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            compile_unet=False,
        )
        image = output.images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=init_image,
            generator=generator,
            strength=0.75,
            guidance_scale=7.5,
            output_type="np",
            return_dict=False,
            compile_unet=False,
        )
        image_from_tuple = output[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)

        # Do not modify any seed number, or this test will not pass
        expected_slice = np.array([0.4213, 0.5489, 0.5102, 0.5320, 0.6574, 0.5861, 0.6171, 0.5866, 0.5160])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_img2img_pipeline(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/fantasy_landscape.png"
        )
        init_image = init_image.resize((768, 512))
        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            safety_checker=self.dummy_safety_checker,
            use_auth_token=True,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
            compile_unet=False,
        )
        image = output.images[0]

        assert image.shape == (512, 768, 3)
        # img2img is flaky across GPUs even in fp32, so using MAE here
        assert np.abs(expected_image - image).mean() < 1e-2

    def test_stable_diffusion_img2img_pipeline_k_lms(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/fantasy_landscape_k_lms.png"
        )
        init_image = init_image.resize((768, 512))
        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            scheduler=lms,
            safety_checker=self.dummy_safety_checker,
            use_auth_token=True,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A fantasy landscape, trending on artstation"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
            compile_unet=False,
        )
        image = output.images[0]

        assert image.shape == (512, 768, 3)
        # img2img is flaky across GPUs even in fp32, so using MAE here
        assert np.abs(expected_image - image).mean() < 1e-2


if __name__ == "__main__":
    unittest.main()
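
Each test above checks its output the same way: fix the RNG seed, render, then compare either a 3x3 corner slice of the last channel (the fast tests) or the whole image (the pretrained-pipeline tests) against a reference within a small tolerance. A stripped-down sketch of those two checks, using placeholder data rather than the hard-coded references from the tests:

import numpy as np

# Pipeline output in "np" format: (batch, height, width, channels), values in [0, 1].
image = np.random.default_rng(0).random((1, 32, 32, 3), dtype=np.float32)

# Fast tests: element-wise max error on a 3x3 corner slice of the last channel.
image_slice = image[0, -3:, -3:, -1]
expected_slice = image_slice.flatten()  # the real tests hard-code 9 reference values
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

# Pretrained-pipeline tests: mean absolute error against a full reference image,
# which tolerates the per-pixel nondeterminism seen across GPUs.
expected_image = image[0]
assert np.abs(expected_image - image[0]).mean() < 1e-2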
