[core] AnimateDiff SparseCtrl #8897

Merged
merged 40 commits into from Jul 26, 2024

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Jul 18, 2024

What does this PR do?

Adds support for the long-awaited AnimateDiff SparseCtrl 🚀

Paper: https://arxiv.org/abs/2311.16933
Project page: https://guoyww.github.io/projects/SparseCtrl/
Code: https://github.com/guoyww/AnimateDiff/

SparseCtrl RGB: https://huggingface.co/a-r-r-o-w/animatediff-sparsectrl-rgb
SparseCtrl Scribble: https://huggingface.co/a-r-r-o-w/animatediff-sparsectrl-scribble
SparseCtrl Motion Lora: https://huggingface.co/a-r-r-o-w/animatediff-motion-lora-v1-5-3

TODO:

  • convert scribble variant, lora
  • verify outputs with original implementation
  • show example results
  • tests
  • docs
  • move converted checkpoints to https://huggingface.co/guoyww

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@DN6 @sayakpaul

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review July 18, 2024 16:44
@a-r-r-o-w
Member Author

Posting some example snippets and results. All input images/videos used in the code can be found here.

Scribble - multiple image interpolation

Code
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-scribble"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = ["scribble-1.png", "scribble-2.png", "scribble-3.png"]
condition_frame_indices = [0, 8, 15]
images = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=images,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=10,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_scribble_multiple.gif")

Scribble - canny inputs + low frame rate input

Code
import requests
from io import BytesIO

import imageio
import torch
from PIL import Image

from controlnet_aux.processor import Processor
from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


def load_video(file_path: str):
    images = []

    if file_path.startswith(('http://', 'https://')):
        # If the file_path is a URL
        response = requests.get(file_path)
        response.raise_for_status()
        content = BytesIO(response.content)
        vid = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        vid = imageio.get_reader(file_path)

    for frame in vid:
        pil_image = Image.fromarray(frame)
        images.append(pil_image)

    return images


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-scribble"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

processor = Processor("canny")

prompt = "1girl, dancing, bustling metropolis, anime, vibrant colors"
negative_prompt = "low quality, worst quality"
frames = load_video("input.gif")
assert len(frames) >= 16

# get low frame rate input
condition_frame_indices = [1, 4, 8, 12, 15]
conditioning_frames = [processor(frames[i]) for i in condition_frame_indices]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=conditioning_frames,
    height=768,
    width=512,
    num_inference_steps=25,
    guidance_scale=10,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_scribble_low_frame_rate.gif")

Scribble - IP Adapter

Code
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-scribble"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.set_ip_adapter_scale({
    "down": {"block_2": [0.0, 1.0]},
    "up": {"block_0": [0.0, 1.0, 0.0]},
})

ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")

image_files = ["scribble-1.png", "scribble-2.png", "scribble-3.png"]
condition_frame_indices = [0, 8, 15]
images = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt="drone view of buildings",
    negative_prompt="low quality, worst quality",
    conditioning_frames=images,
    ip_adapter_image=ip_adapter_image,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=10,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_scribble_ipadapter.gif")

RGB - low frame rate input

Code
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-rgb"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")

prompt = "two people holding hands in a field with wind turbines in the background"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = ["lfps-1.png", "lfps-2.png", "lfps-3.png", "lfps-4.png"]
condition_frame_indices = [0, 5, 10, 15]
images = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=images,
    height=256,
    width=384,
    num_inference_steps=25,
    guidance_scale=7,
    controlnet_conditioning_scale=1.0,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_rgb_lfps.gif")

RGB - single image

Code
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-rgb"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")

prompt = "closeup face photo of man in black clothes, night city street, bokeh, fireworks in background"
negative_prompt = "low quality, worst quality"
image = load_image("firework.png")

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=image,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=8.5,
    controlnet_conditioning_scale=1.0,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_rgb_single.gif")

I haven't been able to fully replicate SparseCtrl RGB yet, as can be seen from the odd outputs compared to the original demos. However, looking at the original demos closely, the same sudden flicker behaviour is visible there too, which could mean the implementation is okay and the model itself is just of limited quality. There are existing reports of this, so it is quite possible that what we're observing is expected and the RGB model is simply not very good: guoyww/AnimateDiff#274 and guoyww/AnimateDiff#278.

Comment on lines 68 to 70
api = HfApi()
repo_id = api.create_repo(args.output_path, exist_ok=True).repo_id
api.upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
Member

We can directly use create_repo() and upload_folder() from huggingface_hub. We don't need to use HfApi() here.
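
For reference, here is a minimal sketch of the suggested change; `output_path` below is a hypothetical stand-in for the conversion script's `args.output_path`.

```py
# Minimal sketch of the suggestion: create_repo() and upload_folder() are
# module-level helpers in huggingface_hub, so instantiating HfApi() is not needed.
from huggingface_hub import create_repo, upload_folder

output_path = "animatediff-sparsectrl-rgb"  # hypothetical stand-in for args.output_path

# Create the repo (no error if it already exists) and push the local folder to it.
repo_id = create_repo(output_path, exist_ok=True).repo_id
upload_folder(repo_id=repo_id, folder_path=output_path, repo_type="model")
```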

Comment on lines 24 to 35
## Example for loading SparseControlNetModel

```py
import torch
from diffusers import SparseControlNetModel

# fp32 variant in float16
controlnet = SparseControlNetModel.from_pretrained("a-r-r-o-w/animatediff-sparsectrl-rgb", torch_dtype=torch.float16)

# fp16 variant
controlnet = SparseControlNetModel.from_pretrained("a-r-r-o-w/animatediff-sparsectrl-rgb", variant="fp16", torch_dtype=torch.float16)
```
Member

Is this a full-fledged example? I would prefer an end-to-end example here.

Also, I think we should move the "animatediff-sparsectrl-rgb" checkpoints to their appropriate orgs.

Member Author

> Is this a full-fledged example? I would prefer an end-to-end example here.

This is just the model documentation, no? I've noticed we have simple examples like this for model docs and a full-fledged example in the pipeline docs. But I'm okay with adding a full example here as well.

> Also, I think we should move the "animatediff-sparsectrl-rgb" checkpoints to their appropriate orgs.

Yes, I will need help with this. There are 3 checkpoints that will need to be moved.

for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)

def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
Member

But we are not doing any checkpointing in the forward(). Perhaps that needs to be implemented?

Member Author

Will take a look at this. Some of the code here is a remnant of copying ControlNetModel.

Member Author

This seems correct, and we don't have to make any changes to the forward here if I follow the implementation of ControlNetModel as the source of truth. DownBlockMotion, CrossAttnDownBlockMotion and UNetMidBlock2DCrossAttn already implement the checkpointing themselves.
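
For context, a minimal sketch of the pattern being discussed, following the convention used by ControlNetModel and the motion blocks in diffusers; the class and attribute names below are simplified stand-ins, not the exact code in this PR.

```py
# Sketch of the common diffusers pattern: the parent model only toggles a flag
# on its child blocks, and each block decides in its own forward() whether to
# wrap its layers with activation checkpointing.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class SketchControlNetModel(nn.Module):
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # Only flip the flag; the blocks handle the actual checkpointing.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value


class SketchDownBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.resnet = nn.Linear(8, 8)
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # When training with the flag enabled, recompute activations in backward
        # instead of storing them, trading compute for memory.
        if self.training and self.gradient_checkpointing:
            return torch.utils.checkpoint.checkpoint(self.resnet, hidden_states)
        return self.resnet(hidden_states)
```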

Comment on lines 203 to 215
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
concat_conditioning_mask: bool = True,
use_simplified_condition_embedding: bool = True,
Member

@sayakpaul sayakpaul Jul 19, 2024

Do we need to have all the config variables here? Can we reduce some of them?

Member

@sayakpaul sayakpaul left a comment

Thanks! I left high-level comments. For a deeper review, I will defer to @DN6. I can then review one more time after that.

@@ -58,6 +58,7 @@ def get_down_block(
resnet_time_scale_shift: str = "default",
temporal_num_attention_heads: int = 8,
temporal_max_seq_length: int = 32,
temporal_double_self_attention: bool = True,
Collaborator

We can remove this argument from here. I would import the blocks directly into the model and set the argument directly in the block.

@@ -1574,7 +1581,7 @@ def __init__(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
num_layers=temporal_transformer_layers_per_block[i],
norm_num_groups=temporal_norm_num_groups,
Collaborator

Nice catch here. Can we just remove the temporal_norm_num_groups argument from the UpBlockMotion class?

Member Author

@a-r-r-o-w a-r-r-o-w Jul 23, 2024

Let's do this in #8846? That's where this change comes from, and it is needed for the tests with the smaller unet/motion modeling components to pass.

@a-r-r-o-w
Member Author

I've updated the documentation images with links from the Hub.

@DN6 The checkpoints are ready to be moved. Updated their model cards:

@a-r-r-o-w
Member Author

Phew! The tests have passed - thanks for bearing with me on this one :)

@DN6 DN6 merged commit 5c53ca5 into main Jul 26, 2024
18 checks passed
@a-r-r-o-w a-r-r-o-w deleted the animatediff/sparsectrl branch July 26, 2024 12:16
@a-r-r-o-w a-r-r-o-w mentioned this pull request Jul 28, 2024
10 tasks
@eps696

eps696 commented Jul 29, 2024

@a-r-r-o-w thanks for the response!
Alas, it still does not seem to handle the input image properly, giving similar outputs no matter what.
The code is copied from "RGB single image" here #8897 (comment), with the latest diffusers from the GitHub main branch.
(attached result GIFs: animatediff_sparsectrl_rgb_single, animatediff_sparsectrl_rgb_single-2)

@a-r-r-o-w
Member Author

a-r-r-o-w commented Jul 29, 2024

@eps696 Have you tried these examples with the original AnimateDiff code base? The demos they show seem to have the first frame being a darker color than the remaining frames, and there are some reports of this as well: guoyww/AnimateDiff#278 and guoyww/AnimateDiff#274.

That said, I find it bizarre that it is not taking the input image into account at all and is only following the prompt. I will debug this to find differences between our implementation and theirs. It would be a great help if you could check out their repo and see whether you get the same incorrect behaviour as here, since I recently deleted their codebase, don't have it set up at the moment, and have other priorities.

Also, the problem seems to come only from the RGB checkpoint, yes? Hopefully there shouldn't be any problems with Scribble, but please report any you find.

@eps696

eps696 commented Jul 29, 2024

@a-r-r-o-w the difference in the keyframe color is a known issue; it happened with other implementations too, so maybe it is indeed a model problem. Yet the diffusers implementation seems to somehow miss the overall motion propagation (from the keyframe across the video), which I didn't see elsewhere.

I'll check the original implementation; for now, here are tests from ComfyUI (with the same inputs and settings):
(attached result GIFs: ad_00040a, ad_00040b)

@a-r-r-o-w
Member Author

a-r-r-o-w commented Jul 29, 2024

@chuck-ma You mentioned here that you found the cause of the first frame being a different color than the rest - could you please share what changes you made?

@a-r-r-o-w
Member Author

> @a-r-r-o-w the difference in the keyframe color is a known issue; it happened with other implementations too, so maybe it is indeed a model problem. Yet the diffusers implementation seems to somehow miss the overall motion propagation (from the keyframe across the video), which I didn't see elsewhere.

Okay, I might have overlooked some implementation details if these are the ComfyUI results 😶‍🌫️ I will take a better look soon!

@eps696

eps696 commented Jul 29, 2024

@a-r-r-o-w just to confirm that the original AnimateDiff code also works as expected:
(attached: result GIF, source)

@aihopper

aihopper commented Sep 2, 2024

How high is fixing this issue on your priority list? It'd be helpful to know what to expect so we can get aligned with your roadmap.

@a-r-r-o-w
Member Author

> How high is fixing this issue on your priority list? It'd be helpful to know what to expect so we can get aligned with your roadmap.

It was high on my priority list when this was added, and I spent some time looking at numerical differences. I wasn't able to figure out what's wrong, though, and other important things are on my plate at the moment. It would be awesome if you'd like to take a look, but I have a few more CogVideoX follow-ups before I can get back to this. I think only RGB SparseCtrl is broken, so most likely it is just some simple internal state in the controlnet that I'm not using correctly.

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* initial sparse control model draft

* remove unnecessary implementation

* copy animatediff pipeline

* remove deprecated callbacks

* update

* update pipeline implementation progress

* make style

* make fix-copies

* update progress

* add partially working pipeline

* remove debug prints

* add model docs

* dummy objects

* improve motion lora conversion script

* fix bugs

* update docstrings

* remove unnecessary model params; docs

* address review comment

* add copied from to zero_module

* copy animatediff test

* add fast tests

* update docs

* update

* update pipeline docs

* fix expected slice values

* fix license

* remove get_down_block usage

* remove temporal_double_self_attention from get_down_block

* update

* update docs with org and documentation images

* make from_unet work in sparsecontrolnetmodel

* add latest freeinit test from #8969

* make fix-copies

* LoraLoaderMixin -> StableDiffusionLoraLoaderMixin