Description
Is your feature request related to a problem? Please describe.
The current AnimateDiffSDXLPipeline supports neither a single ControlNet nor multiple ControlNets.
I've been working on this for several days by combining StableDiffusionXLControlNetAdapterPipeline and AnimateDiffControlNetPipeline from the community folder, but without success so far.
Describe the solution you'd like.
The idea is to extract the poses of a character from a video and then, using the SDXL pose ControlNet, condition AnimateDiffSDXL on that information to produce a video of another character.
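For context, the pose conditioning frames used below could be produced with the openpose detector from controlnet_aux; this is only a sketch of the preprocessing step, and the input/output paths are placeholders rather than part of my actual setup:

import glob

from controlnet_aux import OpenposeDetector
from PIL import Image

# Detect the character pose in every extracted video frame and save the skeleton images
openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
for idx, frame_path in enumerate(sorted(glob.glob("./video_frames/*.png"))):
    frame = Image.open(frame_path).convert("RGB")
    pose_image = openpose(frame)  # PIL image containing the detected skeleton
    pose_image.save(f"./pose_frame/pose_extracted_{idx:05d}_.png")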
The new AnimateDiffSDXLControlnetPipeline should be callable like this:
import torch
from PIL import Image

from diffusers import ControlNetModel, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
# AnimateDiffSDXLControlnetPipeline is the proposed pipeline; it does not exist in diffusers yet

adapter = MotionAdapter.from_pretrained(
    "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
)
controlnet = [
    ControlNetModel.from_pretrained(
        "diffusers/controlnet-depth-sdxl-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda"),
    ControlNetModel.from_pretrained(
        "thibaud/controlnet-openpose-sdxl-1.0",
        torch_dtype=torch.float16,
    ).to("cuda"),
]
# Define model ID and scheduler
# model_id = "stabilityai/stable-diffusion-xl-base-1.0"
model_id = "./pytorch_model/xl-1.0/XL_BASE/"
scheduler = DDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    clip_sample=False,
    timestep_spacing="linspace",
    beta_schedule="linear",
    steps_offset=1,
)
# Load conditioning frames
conditioning_frames = []
for i in range(1, 16 + 1):
    conditioning_frames.append(Image.open(f"./pose_frame/pose_extracted_000{i + 25}_.png"))
pipe = AnimateDiffSDXLControlnetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    motion_adapter=adapter,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
# Enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
# Generate the output
output = pipe(
    prompt="an adorable gecko dancing in the desert, scatter lights, realistic, high quality",
    negative_prompt="low quality, worst quality, extra limbs",
    num_inference_steps=20,
    guidance_scale=8,
    width=1024,
    height=1024,
    num_frames=16,
    conditioning_frames=conditioning_frames,
    controlnet_conditioning_scale=[0.8, 0.8],  # one entry per ControlNet
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 1.0],
)
# Extract frames and export to GIF
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
Describe alternatives you've considered.
None yet.
Additional context.
There is a shape mismatch when providing the inputs to the ControlNet in my modified denoising loop:
# controlnet(s) inference
if guess_mode and self.do_classifier_free_guidance:
    # Infer ControlNet only for the conditional batch.
    control_model_input = latents
    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
    controlnet_added_cond_kwargs = {
        "text_embeds": add_text_embeds.chunk(2)[1],
        "time_ids": add_time_ids.chunk(2)[1],
    }
else:
    control_model_input = latent_model_input_controlnet
    controlnet_prompt_embeds = prompt_embeds
    controlnet_added_cond_kwargs = added_cond_kwargs

# Repeat the prompt embeddings once per frame so they match the flattened latent batch
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)

if isinstance(controlnet_keep[i], list):
    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
    controlnet_cond_scale = controlnet_conditioning_scale
    if isinstance(controlnet_cond_scale, list):
        controlnet_cond_scale = controlnet_cond_scale[0]
    cond_scale = controlnet_cond_scale * controlnet_keep[i]

# Flatten the frame dimension into the batch dimension: (B, C, F, H, W) -> (B * F, C, H, W)
control_model_input = torch.transpose(control_model_input, 1, 2)
control_model_input = control_model_input.reshape(
    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
)

down_block_res_samples, mid_block_res_sample = self.controlnet(
    control_model_input,
    t,
    encoder_hidden_states=controlnet_prompt_embeds,
    controlnet_cond=conditioning_frames,
    conditioning_scale=cond_scale,
    guess_mode=guess_mode,
    added_cond_kwargs=controlnet_added_cond_kwargs,  # <-- shape mismatch here
    return_dict=False,
)
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    cross_attention_kwargs=cross_attention_kwargs,
    added_cond_kwargs=added_cond_kwargs,
    return_dict=False,
    # down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # t2iadapter
    down_block_additional_residuals=down_block_res_samples,  # controlnet
    mid_block_additional_residual=mid_block_res_sample,  # controlnet
)[0]
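My guess is that the mismatch comes from the SDXL added-condition tensors (text_embeds and time_ids) still having the per-prompt batch size while the ControlNet input has been flattened to batch * num_frames. Below is a minimal sketch of how they could be repeated per frame, mirroring what is already done for controlnet_prompt_embeds above; this is only an assumption about the fix, not a confirmed solution:

# Sketch only: repeat the added-condition tensors per frame so their batch dimension
# matches the flattened (batch * num_frames) ControlNet input
controlnet_added_cond_kwargs = {
    "text_embeds": controlnet_added_cond_kwargs["text_embeds"].repeat_interleave(num_frames, dim=0),
    "time_ids": controlnet_added_cond_kwargs["time_ids"].repeat_interleave(num_frames, dim=0),
}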