
Commit e469262

jackalcooper and strint authored
Add multi res example (#151)
Co-authored-by: Xiaoyu Xu <[email protected]>
1 parent 48d8602 commit e469262

File tree

3 files changed: +113 -44 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -171,3 +171,5 @@ op_prof.csv
 *.lock
 *.png
 log
+unet_graphs
+*.json

examples/unet_save_and_load.sh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+set -eu
+python3 examples/unet_torch_interplay.py --save
+python3 examples/unet_torch_interplay.py --load

examples/unet_torch_interplay.py

Lines changed: 108 additions & 44 deletions
@@ -24,23 +24,33 @@
 from tqdm import tqdm


-def mock_wrapper(f):
-    import sys
+class MockCtx(object):
+    def __enter__(self):
+        flow.mock_torch.enable(lazy=True)

-    flow.mock_torch.enable(lazy=True)
-    ret = f()
-    flow.mock_torch.disable()
-    # TODO: this trick of py mod purging will be removed
-    tmp = sys.modules.copy()
-    for x in tmp:
-        if x.startswith("diffusers"):
-            del sys.modules[x]
-    return ret
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        flow.mock_torch.disable()


-class UNetGraph(flow.nn.Graph):
+def get_unet(token):
+    from diffusers import UNet2DConditionModel
+
+    unet = UNet2DConditionModel.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        use_auth_token=token,
+        revision="fp16",
+        torch_dtype=flow.float16,
+        subfolder="unet",
+    )
+    with flow.no_grad():
+        unet = unet.to("cuda")
+    return unet
+
+
+class UNetGraphWithCache(flow.nn.Graph):
+    @flow.nn.Graph.with_dynamic_input_shape(size=9)
     def __init__(self, unet):
-        super().__init__()
+        super().__init__(enable_get_runtime_state_dict=True)
         self.unet = unet
         self.config.enable_cudnn_conv_heuristic_search_algo(False)
         self.config.allow_fuse_add_to_output(True)
@@ -51,58 +61,108 @@ def build(self, latent_model_input, t, text_embeddings):
             latent_model_input, t, encoder_hidden_states=text_embeddings
         ).sample

+    def warmup_with_arg(self, arg_meta_of_sizes):
+        for arg_metas in arg_meta_of_sizes:
+            print(f"warmup {arg_metas=}")
+            arg_tensors = [flow.empty(a[0], dtype=a[1]).to("cuda") for a in arg_metas]
+            self(*arg_tensors)  # build and warmup

-def get_graph(token):
-    from diffusers import UNet2DConditionModel
+    def warmup_with_load(self, file_path):
+        state_dict = flow.load(file_path)
+        self.load_runtime_state_dict(state_dict)

-    with flow.no_grad():
-        unet = UNet2DConditionModel.from_pretrained(
-            "runwayml/stable-diffusion-v1-5",
-            use_auth_token=token,
-            revision="fp16",
-            torch_dtype=flow.float16,
-            subfolder="unet",
-        )
-        unet = unet.to("cuda")
-        return UNetGraph(unet)
+    def save_graph(self, file_path):
+        state_dict = self.runtime_state_dict()
+        flow.save(state_dict, file_path)
+
+
+def image_dim(i):
+    return 768 + 128 * i
+
+
+def noise_shape(batch_size, num_channels, image_w, image_h):
+    sizes = (image_w // 8, image_h // 8)
+    return (batch_size, num_channels) + sizes
+
+
+def get_arg_meta_of_sizes(batch_sizes, resolution_scales, num_channels):
+    return [
+        [
+            (
+                noise_shape(batch_size, num_channels, image_dim(i), image_dim(j)),
+                flow.float16,
+            ),
+            ((1,), flow.int64),
+            ((batch_size, 77, 768), flow.float16),
+        ]
+        for batch_size in batch_sizes
+        for i in resolution_scales
+        for j in resolution_scales
+    ]


 @click.command()
 @click.option("--token")
-@click.option("--repeat", default=1000)
+@click.option("--repeat", default=100)
 @click.option("--sync_interval", default=50)
-def benchmark(token, repeat, sync_interval):
+@click.option("--save", is_flag=True)
+@click.option("--load", is_flag=True)
+@click.option("--file", type=str, default="./unet_graphs")
+def benchmark(token, repeat, sync_interval, save, load, file):
+    RESOLUTION_SCALES = [2, 1, 0]
+    BATCH_SIZES = [2]
+    # TODO: reproduce bug caused by changing batch
+    # BATCH_SIZES = [4, 2]
+
     # create a mocked unet graph
-    unet_graph = mock_wrapper(lambda: get_graph(token))
+    num_channels = 4
+
+    warmup_meta_of_sizes = get_arg_meta_of_sizes(BATCH_SIZES, RESOLUTION_SCALES, num_channels)
+    for (i, m) in enumerate(warmup_meta_of_sizes):
+        print(f"warmup case #{i + 1}:", m)
+    with MockCtx():
+        unet = get_unet(token)
+        unet_graph = UNetGraphWithCache(unet)
+        if load == True:
+            print("loading graphs...")
+            unet_graph.warmup_with_load(file)
+        else:
+            print("warmup with arguments...")
+            unet_graph.warmup_with_arg(warmup_meta_of_sizes)

     # generate inputs with torch
     from diffusers.utils import floats_tensor
     import torch

-    batch_size = 2
-    num_channels = 4
-    sizes = (64, 64)
-    noise = (
-        floats_tensor((batch_size, num_channels) + sizes).to("cuda").to(torch.float16)
-    )
-    print(f"{type(noise)=}")
     time_step = torch.tensor([10]).to("cuda")
-    encoder_hidden_states = (
-        floats_tensor((batch_size, 77, 768)).to("cuda").to(torch.float16)
-    )
-
-    # convert to oneflow tensors
-    [noise, time_step, encoder_hidden_states] = [
-        flow.utils.tensor.from_torch(x)
-        for x in [noise, time_step, encoder_hidden_states]
+    encoder_hidden_states_of_sizes = {
+        batch_size: floats_tensor((batch_size, 77, 768)).to("cuda").to(torch.float16)
+        for batch_size in BATCH_SIZES
+    }
+    noise_of_sizes = [
+        floats_tensor(noise_shape(batch_size, num_channels, image_dim(i), image_dim(j)))
+        .to("cuda")
+        .to(torch.float16)
+        for batch_size in BATCH_SIZES
+        for i in RESOLUTION_SCALES
+        for j in RESOLUTION_SCALES
     ]
-    unet_graph(noise, time_step, encoder_hidden_states)
+    noise_of_sizes = [flow.utils.tensor.from_torch(x) for x in noise_of_sizes]
+    encoder_hidden_states_of_sizes = {
+        k: flow.utils.tensor.from_torch(v) for k, v in encoder_hidden_states_of_sizes.items()
+    }
+    # convert to oneflow tensors
+    time_step = flow.utils.tensor.from_torch(time_step)

     flow._oneflow_internal.eager.Sync()
     import time

     t0 = time.time()
     for r in tqdm(range(repeat)):
+        import random
+
+        noise = random.choice(noise_of_sizes)
+        encoder_hidden_states = encoder_hidden_states_of_sizes[noise.shape[0]]
         out = unet_graph(noise, time_step, encoder_hidden_states)
         # convert to torch tensors
         out = flow.utils.tensor.to_torch(out)
@@ -116,6 +176,10 @@ def benchmark(token, repeat, sync_interval):
         f"Finish {repeat} steps in {duration:.3f} seconds, average {throughput:.2f}it/s"
     )

+    if save:
+        print("saving graphs...")
+        unet_graph.save_graph(file)
+

 if __name__ == "__main__":
     print(f"{flow.__path__=}")
