
[core] AnimateDiff SparseCtrl #8897


Merged
40 commits merged from animatediff/sparsectrl into main on Jul 26, 2024
Commits (40)
1b9d59d
initial sparse control model draft
a-r-r-o-w Jul 17, 2024
b12a0cb
remove unnecessary implementation
a-r-r-o-w Jul 17, 2024
b3d3d4c
copy animatediff pipeline
a-r-r-o-w Jul 17, 2024
b609691
remove deprecated callbacks
a-r-r-o-w Jul 18, 2024
df7792f
update
a-r-r-o-w Jul 18, 2024
bd587b6
update pipeline implementation progress
a-r-r-o-w Jul 18, 2024
9dbc555
make style
a-r-r-o-w Jul 18, 2024
927b572
make fix-copies
a-r-r-o-w Jul 18, 2024
713ef0a
update progress
a-r-r-o-w Jul 18, 2024
6a81908
add partially working pipeline
a-r-r-o-w Jul 18, 2024
6485d7f
remove debug prints
a-r-r-o-w Jul 18, 2024
de3a9d7
add model docs
a-r-r-o-w Jul 18, 2024
3227bcf
dummy objects
a-r-r-o-w Jul 18, 2024
4ccbfba
improve motion lora conversion script
a-r-r-o-w Jul 18, 2024
9ffb01b
fix bugs
a-r-r-o-w Jul 18, 2024
ede45dc
update docstrings
a-r-r-o-w Jul 18, 2024
be8667d
Merge branch 'main' into animatediff/sparsectrl
a-r-r-o-w Jul 18, 2024
f9f0a1e
remove unnecessary model params; docs
a-r-r-o-w Jul 19, 2024
ed8cd72
address review comment
a-r-r-o-w Jul 19, 2024
4b6840d
add copied from to zero_module
a-r-r-o-w Jul 19, 2024
e10a1c8
copy animatediff test
a-r-r-o-w Jul 19, 2024
a4d6403
add fast tests
a-r-r-o-w Jul 22, 2024
1ea9462
update docs
a-r-r-o-w Jul 22, 2024
1475c74
Merge branch 'main' into animatediff/sparsectrl
a-r-r-o-w Jul 22, 2024
c7cbcc5
update
a-r-r-o-w Jul 22, 2024
f382a0f
update pipeline docs
a-r-r-o-w Jul 22, 2024
23c5893
Merge branch 'main' into animatediff/sparsectrl
a-r-r-o-w Jul 22, 2024
777fdf5
fix expected slice values
a-r-r-o-w Jul 22, 2024
987228a
fix license
a-r-r-o-w Jul 23, 2024
0540736
remove get_down_block usage
a-r-r-o-w Jul 23, 2024
bed8e39
remove temporal_double_self_attention from get_down_block
a-r-r-o-w Jul 23, 2024
e6e6733
update
a-r-r-o-w Jul 23, 2024
dca1599
Merge branch 'main' into animatediff/sparsectrl
a-r-r-o-w Jul 23, 2024
c367b7a
update docs with org and documentation images
a-r-r-o-w Jul 23, 2024
a9ad70f
make from_unet work in sparsecontrolnetmodel
a-r-r-o-w Jul 23, 2024
444725d
Merge branch 'main' into animatediff/sparsectrl
a-r-r-o-w Jul 23, 2024
3cfd4fd
Merge branch 'main' into animatediff/sparsectrl
a-r-r-o-w Jul 26, 2024
56f1cc4
add latest freeinit test from #8969
a-r-r-o-w Jul 26, 2024
8e37af0
make fix-copies
a-r-r-o-w Jul 26, 2024
6b0d053
LoraLoaderMixin -> StableDiffsuionLoraLoaderMixin
a-r-r-o-w Jul 26, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -267,6 +267,8 @@
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
title: Models
- isExpanded: false
sections:
46 changes: 46 additions & 0 deletions docs/source/en/api/models/controlnet_sparsectrl.md
@@ -0,0 +1,46 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# SparseControlNetModel

SparseControlNetModel is an implementation of ControlNet for [AnimateDiff](https://arxiv.org/abs/2307.04725).

ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.

The SparseCtrl version of ControlNet was introduced by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933) to achieve controlled generation in text-to-video diffusion models.

The abstract from the paper is:

*The development of text-to-video (T2V), i.e., generating videos with a given text prompt, has been significantly advanced in recent years. However, relying solely on text prompts often results in ambiguous frame composition due to spatial uncertainty. The research community thus leverages the dense structure signals, e.g., per-frame depth/edge sequences, to enhance controllability, whose collection accordingly increases the burden of inference. In this work, we present SparseCtrl to enable flexible structure control with temporally sparse signals, requiring only one or a few inputs, as shown in Figure 1. It incorporates an additional condition encoder to process these sparse signals while leaving the pre-trained T2V model untouched. The proposed approach is compatible with various modalities, including sketches, depth maps, and RGB images, providing more practical control for video generation and promoting applications such as storyboarding, depth rendering, keyframe animation, and interpolation. Extensive experiments demonstrate the generalization of SparseCtrl on both original and personalized T2V generators. Codes and models will be publicly available at [this https URL](https://guoyww.github.io/projects/SparseCtrl).*

## Example for loading SparseControlNetModel

```python
import torch
from diffusers import SparseControlNetModel

# Load the fp32 variant of the weights in float16 precision
# 1. Scribble checkpoint
controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectrl-scribble", torch_dtype=torch.float16)

# 2. RGB checkpoint
controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectrl-rgb", torch_dtype=torch.float16)

# For loading fp16 variant, pass `variant="fp16"` as an additional parameter
```
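
As noted in the comment above, the fp16 weights can also be requested explicitly instead of casting the fp32 weights. A minimal sketch, assuming the `fp16` variant is published for these checkpoints:

```python
import torch

from diffusers import SparseControlNetModel

# Load the fp16 variant of the weights directly (assumes the repository publishes one)
controlnet = SparseControlNetModel.from_pretrained(
    "guoyww/animatediff-sparsectrl-scribble",
    variant="fp16",
    torch_dtype=torch.float16,
)
```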

## SparseControlNetModel

[[autodoc]] SparseControlNetModel

## SparseControlNetOutput

[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput
190 changes: 189 additions & 1 deletion docs/source/en/api/pipelines/animatediff.md
@@ -100,6 +100,189 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you

</Tip>

### AnimateDiffSparseControlNetPipeline

[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933) by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai introduces a method for achieving controlled generation in text-to-video diffusion models from temporally sparse conditioning signals.

The abstract from the paper is:

*The development of text-to-video (T2V), i.e., generating videos with a given text prompt, has been significantly advanced in recent years. However, relying solely on text prompts often results in ambiguous frame composition due to spatial uncertainty. The research community thus leverages the dense structure signals, e.g., per-frame depth/edge sequences, to enhance controllability, whose collection accordingly increases the burden of inference. In this work, we present SparseCtrl to enable flexible structure control with temporally sparse signals, requiring only one or a few inputs, as shown in Figure 1. It incorporates an additional condition encoder to process these sparse signals while leaving the pre-trained T2V model untouched. The proposed approach is compatible with various modalities, including sketches, depth maps, and RGB images, providing more practical control for video generation and promoting applications such as storyboarding, depth rendering, keyframe animation, and interpolation. Extensive experiments demonstrate the generalization of SparseCtrl on both original and personalized T2V generators. Codes and models will be publicly available at [this https URL](https://guoyww.github.io/projects/SparseCtrl).*

SparseCtrl introduces the following checkpoints for controlled text-to-video generation:

- [SparseCtrl Scribble](https://huggingface.co/guoyww/animatediff-sparsectrl-scribble)
- [SparseCtrl RGB](https://huggingface.co/guoyww/animatediff-sparsectrl-rgb)

#### Using SparseCtrl Scribble

```python
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "guoyww/animatediff-sparsectrl-scribble"
lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
beta_schedule="linear",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
model_id,
motion_adapter=motion_adapter,
controlnet=controlnet,
vae=vae,
scheduler=scheduler,
torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png"
]
condition_frame_indices = [0, 8, 15]
conditioning_frames = [load_image(img_file) for img_file in image_files]

video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=25,
conditioning_frames=conditioning_frames,
controlnet_conditioning_scale=1.0,
controlnet_frame_indices=condition_frame_indices,
generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "output.gif")
```

Here are some sample outputs:

<table align="center">
<tr>
<center>
<b>an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality</b>
</center>
</tr>
<tr>
<td>
<center>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png" alt="scribble-1" />
</center>
</td>
<td>
<center>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png" alt="scribble-2" />
</center>
</td>
<td>
<center>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png" alt="scribble-3" />
</center>
</td>
</tr>
<tr>
<td colspan=3>
<center>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-sparsectrl-scribble-results.gif" alt="an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality" />
</center>
</td>
</tr>
</table>
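
The call above conditions frames 0, 8, and 15 of the clip. Since SparseCtrl is designed to work from only one or a few inputs, a single scribble on the first frame is also valid. Below is a minimal variation of the pipeline call above, reusing the objects already defined (the output filename is illustrative):

```python
# Condition only the first frame with a single scribble; the motion module
# fills in the remaining frames.
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    conditioning_frames=[conditioning_frames[0]],
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=[0],
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "output_single_scribble.gif")
```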

#### Using SparseCtrl RGB

```python
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "guoyww/animatediff-sparsectrl-rgb"
lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3"
vae_id = "stabilityai/sd-vae-ft-mse"
device = "cuda"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
beta_schedule="linear",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
model_id,
motion_adapter=motion_adapter,
controlnet=controlnet,
vae=vae,
scheduler=scheduler,
torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-firework.png")

video = pipe(
prompt="closeup face photo of man in black clothes, night city street, bokeh, fireworks in background",
negative_prompt="low quality, worst quality",
num_inference_steps=25,
conditioning_frames=image,
controlnet_frame_indices=[0],
controlnet_conditioning_scale=1.0,
generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "output.gif")
```

Here are some sample outputs:

<table align="center">
<tr>
<center>
<b>closeup face photo of man in black clothes, night city street, bokeh, fireworks in background</b>
</center>
</tr>
<tr>
<td>
<center>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-firework.png" alt="closeup face photo of man in black clothes, night city street, bokeh, fireworks in background" />
</center>
</td>
<td>
<center>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-sparsectrl-rgb-result.gif" alt="closeup face photo of man in black clothes, night city street, bokeh, fireworks in background" />
</center>
</td>
</tr>
</table>
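
Both examples above move the whole pipeline to the GPU with `.to(device)`. On smaller GPUs, the usual diffusers memory helpers can be used instead. A sketch, assuming `AnimateDiffSparseControlNetPipeline` exposes the same helpers as `AnimateDiffPipeline`:

```python
# Optional, lower-VRAM setup: skip `pipe.to(device)`, offload sub-models to the
# CPU between forward passes, and decode the VAE in slices.
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()
```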

### AnimateDiffSDXLPipeline

AnimateDiff can also be used with SDXL models. This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.
@@ -571,7 +754,6 @@ ckpt_path = "https://huggingface.co/Lightricks/LongAnimateDiff/blob/main/lt_long

adapter = MotionAdapter.from_single_file(ckpt_path, torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)

```

## AnimateDiffPipeline
@@ -580,6 +762,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
- all
- __call__

## AnimateDiffSparseControlNetPipeline

[[autodoc]] AnimateDiffSparseControlNetPipeline
- all
- __call__

## AnimateDiffSDXLPipeline

[[autodoc]] AnimateDiffSDXLPipeline
21 changes: 18 additions & 3 deletions scripts/convert_animatediff_motion_lora_to_diffusers.py
@@ -1,6 +1,8 @@
import argparse
import os

import torch
from huggingface_hub import create_repo, upload_folder
from safetensors.torch import load_file, save_file


@@ -25,8 +27,14 @@ def convert_motion_module(original_state_dict):

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
parser.add_argument(
"--push_to_hub",
action="store_true",
default=False,
help="Whether to push the converted model to the HF or not",
)

return parser.parse_args()

@@ -51,4 +59,11 @@ def get_args():
continue
output_dict.update({f"unet.{module_name}": params})

save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
os.makedirs(args.output_path, exist_ok=True)

filepath = os.path.join(args.output_path, "diffusion_pytorch_model.safetensors")
save_file(output_dict, filepath)

if args.push_to_hub:
repo_id = create_repo(args.output_path, exist_ok=True).repo_id
upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
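
For reference, a motion LoRA converted and saved by this script can be loaded with the same `load_lora_weights` call used in the pipeline examples above. A sketch, assuming a hypothetical local output directory and the motion adapter and base model already shown in the docs:

```python
import torch

from diffusers import AnimateDiffPipeline, MotionAdapter

# "./converted-motion-lora" is a placeholder for the --output_path passed to the script
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
)
pipe.load_lora_weights(
    "./converted-motion-lora",
    weight_name="diffusion_pytorch_model.safetensors",  # filename written by this script
    adapter_name="motion_lora",
)
```
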
83 changes: 83 additions & 0 deletions scripts/convert_animatediff_sparsectrl_to_diffusers.py
@@ -0,0 +1,83 @@
import argparse
from typing import Dict

import torch
import torch.nn as nn

from diffusers import SparseControlNetModel


KEYS_RENAME_MAPPING = {
".attention_blocks.0": ".attn1",
".attention_blocks.1": ".attn2",
".attn1.pos_encoder": ".pos_embed",
".ff_norm": ".norm3",
".norms.0": ".norm1",
".norms.1": ".norm2",
".temporal_transformer": "",
}


def convert(original_state_dict: Dict[str, nn.Module]) -> Dict[str, nn.Module]:
converted_state_dict = {}

for key in list(original_state_dict.keys()):
renamed_key = key
for new_name, old_name in KEYS_RENAME_MAPPING.items():
renamed_key = renamed_key.replace(new_name, old_name)
converted_state_dict[renamed_key] = original_state_dict.pop(key)

return converted_state_dict


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
parser.add_argument(
"--max_motion_seq_length",
type=int,
default=32,
help="Max motion sequence length supported by the motion adapter",
)
parser.add_argument(
"--conditioning_channels", type=int, default=4, help="Number of channels in conditioning input to controlnet"
)
parser.add_argument(
"--use_simplified_condition_embedding",
action="store_true",
default=False,
help="Whether or not to use simplified condition embedding. When `conditioning_channels==4` i.e. latent inputs, set this to `True`. When `conditioning_channels==3` i.e. image inputs, set this to `False`",
)
parser.add_argument(
"--save_fp16",
action="store_true",
default=False,
help="Whether or not to save model in fp16 precision along with fp32",
)
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether or not to push saved model to the HF hub"
)
return parser.parse_args()


if __name__ == "__main__":
args = get_args()

state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict.keys():
state_dict: dict = state_dict["state_dict"]

controlnet = SparseControlNetModel(
conditioning_channels=args.conditioning_channels,
motion_max_seq_length=args.max_motion_seq_length,
use_simplified_condition_embedding=args.use_simplified_condition_embedding,
)

state_dict = convert(state_dict)
controlnet.load_state_dict(state_dict, strict=True)

controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub)
if args.save_fp16:
controlnet = controlnet.to(dtype=torch.float16)
controlnet.save_pretrained(args.output_path, variant="fp16", push_to_hub=args.push_to_hub)
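
Once converted, the output directory is a regular SparseControlNetModel checkpoint and can be loaded as shown in the docs above. A sketch, assuming the script was run with a hypothetical `--output_path ./sparsectrl-converted`:

```python
import torch

from diffusers import SparseControlNetModel

# "./sparsectrl-converted" is a placeholder for the --output_path passed to the script
controlnet = SparseControlNetModel.from_pretrained("./sparsectrl-converted", torch_dtype=torch.float16)

# If the script was run with --save_fp16, the fp16 variant it wrote can be loaded instead
controlnet = SparseControlNetModel.from_pretrained(
    "./sparsectrl-converted", variant="fp16", torch_dtype=torch.float16
)
```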