[core] AnimateDiff SparseCtrl #8897
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Posting some example snippets and results. All input images/videos used in the code can be found here.

Scribble - multiple image interpolation

```py
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-scribble"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = ["scribble-1.png", "scribble-2.png", "scribble-3.png"]
# Each conditioning image is applied at the matching frame index below
condition_frame_indices = [0, 8, 15]
images = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=images,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=10,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_scribble_multiple.gif")
```
Scribble - canny inputs + low frame rate input

```py
import requests
from io import BytesIO

import imageio
import torch
from PIL import Image
from controlnet_aux.processor import Processor

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image


def load_video(file_path: str):
    images = []

    if file_path.startswith(('http://', 'https://')):
        # If the file_path is a URL, download it into memory
        response = requests.get(file_path)
        response.raise_for_status()
        content = BytesIO(response.content)
        vid = imageio.get_reader(content)
    else:
        # Assuming it's a local file path
        vid = imageio.get_reader(file_path)

    for frame in vid:
        pil_image = Image.fromarray(frame)
        images.append(pil_image)

    return images


model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-scribble"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

processor = Processor("canny")

prompt = "1girl, dancing, bustling metropolis, anime, vibrant colors"
negative_prompt = "low quality, worst quality"

frames = load_video("input.gif")
assert len(frames) >= 16

# Get a low frame rate input by sampling sparse keyframes and running them through the Canny processor
condition_frame_indices = [1, 4, 8, 12, 15]
conditioning_frames = [processor(frames[i]) for i in condition_frame_indices]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=conditioning_frames,
    height=768,
    width=512,
    num_inference_steps=25,
    guidance_scale=10,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_scribble_low_frame_rate.gif")
```
Scribble - IP Adapter

```py
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-scribble"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")
pipe.fuse_lora(lora_scale=1.0)

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# Per-block IP-Adapter scales: only the selected attention layers use the image prompt
pipe.set_ip_adapter_scale({
    "down": {"block_2": [0.0, 1.0]},
    "up": {"block_0": [0.0, 1.0, 0.0]},
})

ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")

image_files = ["scribble-1.png", "scribble-2.png", "scribble-3.png"]
condition_frame_indices = [0, 8, 15]
images = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt="drone view of buildings",
    negative_prompt="low quality, worst quality",
    conditioning_frames=images,
    ip_adapter_image=ip_adapter_image,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=10,
    controlnet_conditioning_scale=1.0,
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_scribble_ipadapter.gif")
```
RGB - low frame rate input

```py
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-rgb"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")

prompt = "two people holding hands in a field with wind turbines in the background"
negative_prompt = "low quality, worst quality, letterboxed"

image_files = ["lfps-1.png", "lfps-2.png", "lfps-3.png", "lfps-4.png"]
condition_frame_indices = [0, 5, 10, 15]
images = [load_image(img_file) for img_file in image_files]

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=images,
    height=256,
    width=384,
    num_inference_steps=25,
    guidance_scale=7,
    controlnet_conditioning_scale=1.0,
    # map each conditioning image to its keyframe index
    controlnet_frame_indices=condition_frame_indices,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_rgb_lfps.gif")
```
RGB - single image input

```py
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3"
controlnet_id = "a-r-r-o-w/animatediff-sparsectrl-rgb"
vae_id = "stabilityai/sd-vae-ft-mse"
lora_adapter_id = "a-r-r-o-w/animatediff-motion-lora-v1-5-3"
device = "cuda:1"

motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)
controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    beta_schedule="linear",
    algorithm_type="dpmsolver++",
    use_karras_sigmas=True,
)
pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    model_id,
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora")

prompt = "closeup face photo of man in black clothes, night city street, bokeh, fireworks in background"
negative_prompt = "low quality, worst quality"

# A single RGB image conditioning the first frame
image = load_image("firework.png")

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    conditioning_frames=image,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=8.5,
    controlnet_conditioning_scale=1.0,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "results/animatediff_sparsectrl_rgb_single.gif")
```
I haven't been able to fully replicate SparseCtrl RGB yet, as can be seen from the odd outputs compared to the original demos. However, if you look at the original demos closely, the same sudden flicker behaviour is visible there too, which could mean the implementation is okay and the model quality is simply poor. There are some reports suggesting this is expected and that what we're observing is the RGB model not being very good: guoyww/AnimateDiff#274 and guoyww/AnimateDiff#278.
```py
api = HfApi()
repo_id = api.create_repo(args.output_path, exist_ok=True).repo_id
api.upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
```
We can directly use create_repo() and upload_folder() from huggingface_hub. We don't need to use HfApi() here.
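A minimal sketch of that suggestion (output_path is a hypothetical stand-in for the script's args.output_path):

```py
from huggingface_hub import create_repo, upload_folder

output_path = "animatediff-sparsectrl-converted"  # hypothetical: the folder the conversion script wrote to

# Module-level helpers from huggingface_hub; no HfApi instance needed
repo_id = create_repo(output_path, exist_ok=True).repo_id
upload_folder(repo_id=repo_id, folder_path=output_path, repo_type="model")
```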
## Example for loading SparseControlNetModel

```py
import torch
from diffusers import SparseControlNetModel

# fp32 variant in float16
controlnet = SparseControlNetModel.from_pretrained("a-r-r-o-w/animatediff-sparsectrl-rgb", torch_dtype=torch.float16)

# fp16 variant
controlnet = SparseControlNetModel.from_pretrained("a-r-r-o-w/animatediff-sparsectrl-rgb", variant="fp16", torch_dtype=torch.float16)
```
Is this a full-fledged example? I would prefer an end-to-end example here.
Also, I think we should move the "animatediff-sparsectrl-rgb" checkpoints to their appropriate orgs.
> Is this a full-fledged example? I would prefer an end-to-end example here.

This is just the model documentation, no? I've noticed we have simple examples like this in the model docs and a full-fledged example in the pipeline docs. But it's okay to add a full example here as well.

> Also, I think we should move the "animatediff-sparsectrl-rgb" checkpoints to their appropriate orgs.

Yes, I will need help with this. There are 3 checkpoints that will need to be moved.
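For reference, a condensed end-to-end sketch along those lines, reusing the checkpoints and settings from the pipeline examples earlier in this thread (checkpoint IDs are the current ones from this PR, pending the move to their final orgs):

```py
import torch

from diffusers import AnimateDiffSparseControlNetPipeline
from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel
from diffusers.utils import export_to_gif, load_image

motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16)
controlnet = SparseControlNetModel.from_pretrained("a-r-r-o-w/animatediff-sparsectrl-scribble", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

# A single scribble image conditioning the first frame; the remaining frames are generated freely
image = load_image("scribble-1.png")  # placeholder path, as in the examples above

video = pipe(
    prompt="an aerial view of a cyberpunk city, night time, neon lights",
    negative_prompt="low quality, worst quality",
    conditioning_frames=[image],
    controlnet_frame_indices=[0],
    num_inference_steps=25,
    guidance_scale=10,
    generator=torch.Generator().manual_seed(1337),
).frames[0]
export_to_gif(video, "animatediff_sparsectrl_end_to_end.gif")
```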
```py
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
```
But we are not doing any checkpointing in the forward(). Perhaps that needs to be implemented?
Will take a look at this. Some of the code here is a remnant of copying the ControlNetModel.
This seems correct, and we don't have to make any changes in the forward here if I follow the implementation of ControlNetModel as the source of truth. The DownBlockMotion, CrossAttnDownBlockMotion and UNetMidBlock2DCrossAttn blocks have it implemented.
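For context, a minimal, hypothetical sketch of the pattern being referenced: _set_gradient_checkpointing only flips a flag, and the blocks consult that flag in their own forward passes (TinyBlock below is illustrative, not code from this PR):

```py
import torch
from torch.utils.checkpoint import checkpoint


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)
        self.gradient_checkpointing = False  # the flag toggled by _set_gradient_checkpointing

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # Recompute activations during backward to trade compute for memory
            return checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)


block = TinyBlock()
block.gradient_checkpointing = True  # what _set_gradient_checkpointing(module, True) effectively does
block.train()
out = block(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```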
```py
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
        motion_max_seq_length: int = 32,
        motion_num_attention_heads: int = 8,
        concat_conditioning_mask: bool = True,
        use_simplified_condition_embedding: bool = True,
```
Do we need to have all the config variables here? Can we reduce some of them?
Thanks! I left high-level comments. For a deeper review, I will defer to @DN6. I can then review one more time after that.
```diff
@@ -58,6 +58,7 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
+    temporal_double_self_attention: bool = True,
```
We can remove this argument from here. I would import the blocks directly into the model and set the argument directly in the block.
```diff
@@ -1574,7 +1581,7 @@ def __init__(
     num_attention_heads=temporal_num_attention_heads,
     in_channels=out_channels,
     num_layers=temporal_transformer_layers_per_block[i],
     norm_num_groups=temporal_norm_num_groups,
```
Nice catch here. Can we just remove the temporal_norm_num_groups argument from the UpBlockMotion class?
Let's do this in #8846? That's where this change comes from, in order for the tests with smaller unet/motion modeling components to pass.
I've updated the documentation images with links from the Hub. @DN6 The checkpoints are ready to be moved; I've updated their model cards.
Phew! The tests have passed - thanks for bearing with me on this one :)
@a-r-r-o-w thanks for the response!
@eps696 Have you tried these examples with the original AnimateDiff code base? The demos they show seem to have the first frame in a different, darker color than the remaining frames, and there are some reports as well: guoyww/AnimateDiff#278 and guoyww/AnimateDiff#274. But I find it bizarre that it is not taking the input image into account at all and only following the prompt. I will debug this to find differences between our implementation and theirs. It'd be of great help if you could check out their repo and see whether you get the same incorrect behaviour as here, since I recently deleted their codebase and do not have it set up at the moment, plus other priorities. Also, the problem seems to come only from the RGB checkpoint, yes? For Scribble there shouldn't be any problems, hopefully, but please report if you find any.
@a-r-r-o-w the difference in the keyframe color is a known issue; it happened with other implementations too, so maybe it is indeed a model problem. Yet the diffusers implementation seems to somehow miss the overall motion propagation (from the keyframe over the video), which I didn't see elsewhere. I'll check the original implementation; for now, here are tests from ComfyUI (with the same inputs and settings):
Okay, I might have overlooked some implementation details if these are the Comfy results 😶🌫️ I will take a better look soon!
@a-r-r-o-w just to confirm that the original AnimateDiff code also works as expected:
How high is fixing this issue on your priority list? It'd be helpful to know what to expect so we can align with your roadmap.
It was high on my priority list when this was added, and I spent some time looking at numerical differences. I wasn't able to figure out what's wrong though, and other important things are on my plate at the moment. It would be awesome if you'd like to take a look, but I have a few more CogVideoX follow-ups before I can get back to this. I think only RGB SparseCtrl is broken, so at best it is probably just some simple internal state in the controlnet that I'm not using correctly.
* initial sparse control model draft
* remove unnecessary implementation
* copy animatediff pipeline
* remove deprecated callbacks
* update
* update pipeline implementation progress
* make style
* make fix-copies
* update progress
* add partially working pipeline
* remove debug prints
* add model docs
* dummy objects
* improve motion lora conversion script
* fix bugs
* update docstrings
* remove unnecessary model params; docs
* address review comment
* add copied from to zero_module
* copy animatediff test
* add fast tests
* update docs
* update
* update pipeline docs
* fix expected slice values
* fix license
* remove get_down_block usage
* remove temporal_double_self_attention from get_down_block
* update
* update docs with org and documentation images
* make from_unet work in sparsecontrolnetmodel
* add latest freeinit test from #8969
* make fix-copies
* LoraLoaderMixin -> StableDiffusionLoraLoaderMixin
What does this PR do?
Adds support for the long-awaited AnimateDiff SparseCtrl 🚀
Paper: https://arxiv.org/abs/2311.16933
Project page: https://guoyww.github.io/projects/SparseCtrl/
Code: https://github.com/guoyww/AnimateDiff/
SparseCtrl RGB: https://huggingface.co/a-r-r-o-w/animatediff-sparsectrl-rgb
SparseCtrl Scribble: https://huggingface.co/a-r-r-o-w/animatediff-sparsectrl-scribble
SparseCtrl Motion Lora: https://huggingface.co/a-r-r-o-w/animatediff-motion-lora-v1-5-3
TODO:
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@DN6 @sayakpaul