Commit 48d8602

Add examples running unet with torch input (#150)
1 parent 49e3946 commit 48d8602

File tree

1 file changed: +123 -0


examples/unet_torch_interplay.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import os

# MLIR graph-compilation optimizations (CSE, op fusion, grouped matmul, NHWC layout)
os.environ["ONEFLOW_MLIR_CSE"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1"
os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1"
os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1"
os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1"

# fused conv-bias and fused linear kernels
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1"
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"

# CUTLASS conv implementation, with tuning warmup
os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1"
os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1"

# allow half-precision accumulation in conv and matmul
os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"
os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1"

os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"

import click
import oneflow as flow
from tqdm import tqdm


def mock_wrapper(f):
    # run f() with oneflow's torch mocking enabled, so `import torch`
    # inside f (e.g. in diffusers) resolves to oneflow
    import sys

    flow.mock_torch.enable(lazy=True)
    ret = f()
    flow.mock_torch.disable()
    # TODO: this trick of py mod purging will be removed
    tmp = sys.modules.copy()
    for x in tmp:
        if x.startswith("diffusers"):
            del sys.modules[x]
    return ret


class UNetGraph(flow.nn.Graph):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet
        self.config.enable_cudnn_conv_heuristic_search_algo(False)
        self.config.allow_fuse_add_to_output(True)

    def build(self, latent_model_input, t, text_embeddings):
        text_embeddings = flow._C.amp_white_identity(text_embeddings)
        return self.unet(
            latent_model_input, t, encoder_hidden_states=text_embeddings
        ).sample


def get_graph(token):
    from diffusers import UNet2DConditionModel

    with flow.no_grad():
        unet = UNet2DConditionModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            use_auth_token=token,
            revision="fp16",
            torch_dtype=flow.float16,
            subfolder="unet",
        )
        unet = unet.to("cuda")
        return UNetGraph(unet)


@click.command()
@click.option("--token")
@click.option("--repeat", default=1000)
@click.option("--sync_interval", default=50)
def benchmark(token, repeat, sync_interval):
    # create a mocked unet graph
    unet_graph = mock_wrapper(lambda: get_graph(token))

    # generate inputs with torch
    from diffusers.utils import floats_tensor
    import torch

    batch_size = 2
    num_channels = 4
    sizes = (64, 64)
    noise = (
        floats_tensor((batch_size, num_channels) + sizes).to("cuda").to(torch.float16)
    )
    print(f"{type(noise)=}")
    time_step = torch.tensor([10]).to("cuda")
    encoder_hidden_states = (
        floats_tensor((batch_size, 77, 768)).to("cuda").to(torch.float16)
    )

    # convert to oneflow tensors
    [noise, time_step, encoder_hidden_states] = [
        flow.utils.tensor.from_torch(x)
        for x in [noise, time_step, encoder_hidden_states]
    ]
    unet_graph(noise, time_step, encoder_hidden_states)

    flow._oneflow_internal.eager.Sync()
    import time

    t0 = time.time()
    for r in tqdm(range(repeat)):
        out = unet_graph(noise, time_step, encoder_hidden_states)
        # convert to torch tensors
        out = flow.utils.tensor.to_torch(out)
        if r == repeat - 1 or r % sync_interval == 0:
            flow._oneflow_internal.eager.Sync()
            print(f"{type(out)=}")
    t1 = time.time()
    duration = t1 - t0
    throughput = repeat / duration
    print(
        f"Finish {repeat} steps in {duration:.3f} seconds, average {throughput:.2f}it/s"
    )


if __name__ == "__main__":
    print(f"{flow.__path__=}")
    print(f"{flow.__version__=}")
    benchmark()
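
For reference, a minimal sketch (not part of this commit) of the torch-to-oneflow tensor interplay the example relies on; it assumes torch and oneflow are installed and a CUDA device is available:

import torch
import oneflow as flow

# a torch CUDA tensor, shaped like the latents used in the benchmark above
x = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16)

# wrap it as a oneflow tensor to feed a oneflow module or nn.Graph ...
y = flow.utils.tensor.from_torch(x)

# ... and convert results back to torch afterwards
z = flow.utils.tensor.to_torch(y)
print(type(y), type(z))  # expect a oneflow tensor and a torch tensor

The example script itself can be run as, e.g., python examples/unet_torch_interplay.py --token <hf_token> (optionally with --repeat and --sync_interval); the token is passed via use_auth_token to download the runwayml/stable-diffusion-v1-5 UNet weights.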

0 commit comments
