
AnimateDiffSDXL + Multi Controlnets support #8664

Open
@Sundragon1993

Description


Is your feature request related to a problem? Please describe.
The current AnimateDiffSDXLPipeline supports neither a single ControlNet nor multiple ControlNets.
I've been working on this for several days by combining StableDiffusionXLControlNetAdapterPipeline and AnimateDiffControlNetPipeline from the community folder, but without success so far.

Describe the solution you'd like.
The idea is to extract a character's poses from a video, then use a pose ControlNet for SDXL so that AnimateDiffSDXL is conditioned on that information and produces a video of another character. A rough sketch of the pose-extraction step is shown below.
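
A minimal sketch of how the conditioning frames could be extracted, assuming controlnet_aux's OpenposeDetector and OpenCV are available; the input video path, frame count, and output file names are placeholders:

import cv2
from PIL import Image
from controlnet_aux import OpenposeDetector

# Placeholder input video; any clip of the source character works
openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
video = cv2.VideoCapture("./input/character_dance.mp4")

pose_frames = []
while len(pose_frames) < 16:
    ok, frame = video.read()
    if not ok:
        break
    # OpenCV decodes to BGR; convert to RGB before running the pose detector
    frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pose_frames.append(openpose(frame))
video.release()

# Save with the naming scheme used in the pipeline example (placeholder indices)
for i, pose in enumerate(pose_frames, start=26):
    pose.save(f"./pose_frame/pose_extracted_000{i}_.png")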

The proposed AnimateDiffSDXLControlnetPipeline should be callable like this:

import torch
from PIL import Image
from diffusers import ControlNetModel, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained(
    "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
)

controlnet = [
    ControlNetModel.from_pretrained(
        "diffusers/controlnet-depth-sdxl-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda"),
    ControlNetModel.from_pretrained(
        "thibaud/controlnet-openpose-sdxl-1.0",
        torch_dtype=torch.float16,
    ).to("cuda"),
]

# Define model ID and scheduler
# model_id = "stabilityai/stable-diffusion-xl-base-1.0"
model_id = "./pytorch_model/xl-1.0/XL_BASE/"
scheduler = DDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    clip_sample=False,
    timestep_spacing="linspace",
    beta_schedule="linear",
    steps_offset=1,
)

# Load conditioning frames
conditioning_frames = []
for i in range(1, 16 + 1):
    conditioning_frames.append(Image.open(f"./pose_frame/pose_extracted_000{i + 25}_.png"))

pipe = AnimateDiffSDXLControlnetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    motion_adapter=adapter,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# Generate the output
output = pipe(
    prompt="an adorable gecko dancing in the desert, scatter lights, realistic, high quality",
    negative_prompt="low quality, worst quality, extra limbs",
    num_inference_steps=20,
    guidance_scale=8,
    width=1024,
    height=1024,
    num_frames=16,
    conditioning_frames=conditioning_frames,
    controlnet_conditioning_scale=[0.8, 0.8],
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 1.0],
)

# Extract frames and export to GIF
frames = output.frames[0]
export_to_gif(frames, "animation.gif")

Describe alternatives you've considered.
None yet.

Additional context.
There is a shape mismatch when providing inputs to the ControlNet (see the annotated line in the snippet below):

# controlnet(s) inference
if guess_mode and self.do_classifier_free_guidance:
    # Infer ControlNet only for the conditional batch.
    control_model_input = latents
    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    controlnet_added_cond_kwargs = {
        "text_embeds": add_text_embeds.chunk(2)[1],
        "time_ids": add_time_ids.chunk(2)[1],
    }
else:
    control_model_input = latent_model_input_controlnet
    controlnet_prompt_embeds = prompt_embeds
    controlnet_added_cond_kwargs = added_cond_kwargs

controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
if isinstance(controlnet_keep[i], list):
    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
    controlnet_cond_scale = controlnet_conditioning_scale
    if isinstance(controlnet_cond_scale, list):
        controlnet_cond_scale = controlnet_cond_scale[0]
    cond_scale = controlnet_cond_scale * controlnet_keep[i]

control_model_input = torch.transpose(control_model_input, 1, 2)
control_model_input = control_model_input.reshape(
    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
)

down_block_res_samples, mid_block_res_sample = self.controlnet(
    control_model_input,
    t,
    encoder_hidden_states=controlnet_prompt_embeds,
    controlnet_cond=conditioning_frames,
    conditioning_scale=cond_scale,
    guess_mode=guess_mode,
    added_cond_kwargs=controlnet_added_cond_kwargs,  # <-- shape mismatch here
    return_dict=False,
)

noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs=cross_attention_kwargs,
    added_cond_kwargs=added_cond_kwargs,
    return_dict=False,
    # down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # t2iadapter
    down_block_additional_residuals=down_block_res_samples,  # controlnet
    mid_block_additional_residual=mid_block_res_sample,  # controlnet
)[0]
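
My guess (untested) is that the mismatch comes from repeating controlnet_prompt_embeds per frame while the SDXL added conditioning is left at the original batch size, so text_embeds / time_ids no longer line up with control_model_input after the (B, C, F, H, W) -> (B * F, C, H, W) reshape. A rough sketch of what might fix it, reusing the variable names from the snippet above:

# Untested sketch: repeat the SDXL added conditioning per frame so its batch
# dimension matches control_model_input after the frame dimension is collapsed.
controlnet_added_cond_kwargs = {
    "text_embeds": controlnet_added_cond_kwargs["text_embeds"].repeat_interleave(num_frames, dim=0),
    "time_ids": controlnet_added_cond_kwargs["time_ids"].repeat_interleave(num_frames, dim=0),
}

# conditioning_frames presumably also need to be preprocessed into a
# (batch * num_frames, 3, H, W) tensor before being passed as controlnet_cond,
# as AnimateDiffControlNetPipeline does, rather than passed as raw PIL images.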
