
Support `.alpha` keys in HiDream LoRAs trained using OneTrainer #11653

Open
@ali-afridi26

Description


Describe the bug

If a HiDream LoRA is trained using OneTrainer, not every key in the checkpoint contains `lora`: some keys end in `.alpha` instead. As a result, this check in the `load_lora_weights` method of `HiDreamImageLoraLoaderMixin` fails: `is_correct_format = all("lora" in key for key in state_dict.keys())`
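A minimal sketch of why the check rejects such a checkpoint. The key names below are hypothetical (they are not copied from the actual OneTrainer file); they only illustrate a state dict where the LoRA weight keys contain `lora` but the per-module alpha keys do not. `is_correct_format_relaxed` is likewise a hypothetical variant, not the diffusers implementation; relaxing the check alone would probably not be enough, since the loader would also have to consume the alpha values.

```python
# Hypothetical key names: weight keys contain "lora", alpha keys do not
keys = [
    "double_stream_blocks.0.block.attn1.to_q.lora_A.weight",
    "double_stream_blocks.0.block.attn1.to_q.lora_B.weight",
    "double_stream_blocks.0.block.attn1.to_q.alpha",
]


def is_correct_format_strict(keys):
    # Same expression as the check quoted from load_lora_weights
    return all("lora" in key for key in keys)


def is_correct_format_relaxed(keys):
    # Hypothetical relaxed check: also accept per-module `.alpha` scalars
    return all("lora" in key or key.endswith(".alpha") for key in keys)


print(is_correct_format_strict(keys))   # False
print(is_correct_format_relaxed(keys))  # True
```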

Reproduction

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import HiDreamImagePipeline
from huggingface_hub import login

login("your_token")  # replace with your Hugging Face access token

text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "terminusresearch/hidream-i1-llama-3.1-8b-instruct",
    subfolder="text_encoder_4",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
).to("cuda", dtype=torch.bfloat16)
# Tokenizers are not torch modules and have no .to() method
tokenizer_4 = AutoTokenizer.from_pretrained(
    "terminusresearch/hidream-i1-llama-3.1-8b-instruct",
    subfolder="tokenizer_4",
    config=text_encoder_4.config,
)
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    text_encoder_4=text_encoder_4,
    tokenizer_4=tokenizer_4,
    torch_dtype=torch.bfloat16,
).to("cuda", dtype=torch.bfloat16)

# Raises "ValueError: Invalid LoRA checkpoint."
pipe.load_lora_weights("RhaegarKhan/OMI_LORA")
image = pipe(
    'A cat holding a sign that says "Hi-Dreams.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("output.png")
```

Logs

```
(.venv) [root@dual-gpu-test sd-base-api]# python ali.py 
Multiple distributions found for package optimum. Picked distribution: optimum-quanto
/runware/sd-base-api/.venv/lib64/python3.12/site-packages/transformers/generation/configuration_utils.py:820: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_attentions` is. When `return_dict_in_generate` is not `True`, `output_attentions` is ignored.
  warnings.warn(
/runware/sd-base-api/.venv/lib64/python3.12/site-packages/transformers/generation/configuration_utils.py:820: UserWarning: `return_dict_in_generate` is NOT set to `True`, but `output_hidden_states` is. When `return_dict_in_generate` is not `True`, `output_hidden_states` is ignored.
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 132.45it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 92.25it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 69.48it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  8.14it/s]
hidream_lora_full.safetensors: 100%|████████████████████████████████████████████████████████████████████████| 30.8M/30.8M [00:00<00:00, 42.7MB/s]
Traceback (most recent call last):
  File "/runware/sd-base-api/ali.py", line 31, in <module>
    pipe.load_lora_weights(f"RhaegarKhan/OMI_LORA")
  File "/runware/sd-base-api/.venv/lib64/python3.12/site-packages/diffusers/loaders/lora_pipeline.py", line 5545, in load_lora_weights
    raise ValueError("Invalid LoRA checkpoint.")
ValueError: Invalid LoRA checkpoint.
```
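Until the loader accepts these keys, one possible workaround (an untested assumption, not a confirmed fix) is to load the state dict manually, fold each `.alpha` into its module's down/A weight, and pass the resulting dict to `pipe.load_lora_weights`, which also accepts a state dict. Because the LoRA update is `(alpha / rank) * B @ A`, scaling `A` by `alpha / rank` preserves the update, assuming the loader then applies a scale of 1 (`lora_alpha` equal to the rank) when no alphas are present. The helper name `fold_alpha_into_lora` and the assumed down-weight names (`<prefix>.lora_A.weight` or `<prefix>.lora_down.weight`) are hypothetical; the remaining keys may still need renaming to the layout diffusers expects for HiDream.

```python
import torch


def fold_alpha_into_lora(state_dict):
    """Hypothetical helper: fold each `<prefix>.alpha` into its down/A weight.

    Assumes the down weight is `<prefix>.lora_A.weight` or
    `<prefix>.lora_down.weight` with shape (rank, in_features).
    """
    converted = dict(state_dict)  # shallow copy; the input dict is not mutated
    alpha_keys = [k for k in state_dict if k.endswith(".alpha")]
    for alpha_key in alpha_keys:
        prefix = alpha_key[: -len(".alpha")]
        alpha = float(converted.pop(alpha_key))
        for suffix in (".lora_A.weight", ".lora_down.weight"):
            down_key = prefix + suffix
            if down_key in converted:
                rank = converted[down_key].shape[0]
                # LoRA update is (alpha / rank) * B @ A, so scale A once here
                converted[down_key] = converted[down_key] * (alpha / rank)
                break
        else:
            raise KeyError(f"No down/A weight found for {alpha_key}")
    return converted
```

In use, the file `hidream_lora_full.safetensors` shown in the log could be read with `safetensors.torch.load_file`, passed through `fold_alpha_into_lora`, and the result handed to `pipe.load_lora_weights(...)`; whether the HiDream loader then accepts the converted keys has not been verified.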

System Info


  • 🤗 Diffusers version: 0.34.0.dev0
  • Platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.34
  • Running on Google Colab?: No
  • Python version: 3.12.9
  • PyTorch version (GPU?): 2.7.0+cu126 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.31.2
  • Transformers version: 4.51.3
  • Accelerate version: 1.6.0
  • PEFT version: 0.15.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA H200, 143771 MiB
    NVIDIA H200, 143771 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no


Who can help?

@sayakpaul

Labels: bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions