Support for Qwen3-VL models #911
Conversation
Please add support for Qwen3-VL.
Hi @dahwin |
@mayankagarwals Thanks for working on Qwen3-VL support! I've done extensive benchmarking of PR #911 with **Qwen3-VL-4B-Instruct** and discovered some critical issues.
## Setup
- **Model**: Qwen3-VL-4B-Instruct (dense, non-MoE)
- **Hardware**: 4x NVIDIA L40S (48GB each)
- **Framework**: MS-Swift with Flash Attention + Liger Kernel (PR #911 branch)
- **Training Config**: batch_size=1, grad_accum=2, bf16, gradient_checkpointing, full parameter training
- **Comparison**: Tested both with and without Liger Kernel using identical hyperparameters
## Test 1: Using `apply_liger_kernel_to_qwen2_vl()` (Initial Attempt)
Since your comment said "The branch already supports FLCE for qwen 3 VL", I initially tried using `apply_liger_kernel_to_qwen2_vl()` as a fallback:
```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
apply_liger_kernel_to_qwen2_vl(
    fused_linear_cross_entropy=True,
    rms_norm=True,
    rope=True,
    swiglu=True,
)
```

Results:

Conclusion: Using

## Test 2: Using
@dahwin Hi, thank you for your tests. There's a detailed analysis of memory with and without Liger FLCE in #517. TLDR: it only cuts the memory usage of the logits-related tensors, so you don't see much difference with a low batch size/short seqlen. With the memory wall due to logits lowered, you can achieve more efficient training by increasing batch size/seqlen as you mentioned, or by disabling gradient checkpointing.
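To make the memory argument concrete, here is a minimal sketch (tensor sizes are illustrative assumptions, not the benchmark config above) contrasting the standard path, which materializes the full logits tensor, with the fused linear cross entropy path, which never does:

```python
import torch
import torch.nn as nn

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Illustrative sizes (assumptions): with a ~150k vocab, the logits tensor for a
# 4k-token batch is ~1.2 GB in bf16, dwarfing the hidden states themselves.
batch_tokens, hidden, vocab = 4096, 2560, 151936
lm_head = nn.Linear(hidden, vocab, bias=False).to("cuda", torch.bfloat16)
hidden_states = torch.randn(batch_tokens, hidden, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, vocab, (batch_tokens,), device="cuda")

# Standard path: materializes a (batch_tokens, vocab) logits tensor before the loss.
logits = lm_head(hidden_states)
loss = nn.functional.cross_entropy(logits, labels)

# Fused path: the lm_head matmul and the loss are computed chunk by chunk inside
# the kernel, so the full logits tensor is never allocated.
flce = LigerFusedLinearCrossEntropyLoss()
loss_fused = flce(lm_head.weight, hidden_states, labels)
```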
@mayankagarwals
Hi @Tcc0403, apologies for the delay; I'm running a little low on time. Current status: If you feel this is time sensitive, feel free to bring it home. I'll try to carve out more time soon and close this out.
Update: the following test is currently failing and still needs to be resolved.
Verified the bf16/fp32 convergence tests pass on H100. Fixed the other failing convergence tests. LGTM!
shimizust left a comment
Thanks for all the effort in adding support for the Qwen3-VL model, @mayankagarwals @Tcc0403
## Summary
Tiny fix re: #930, grad_weight and grad_bias were never set on the no_grad path.

First commit is my fix, second commit was from running `make checkstyle` and a small import change for `qwen3_vl`, as the import was broken and tests were not passing? Believe this was introduced in #911.

## Testing Done
Repro provided in #930 now passes:

```python
import torch
import torch.nn as nn

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

vocab_size, hidden_dim, num_tokens = 1000, 512, 256
device = "cuda" if torch.cuda.is_available() else "cpu"

linear = nn.Linear(hidden_dim, vocab_size, bias=False).to(device)
fused_loss_fn = LigerFusedLinearCrossEntropyLoss()

hidden_states = torch.randn(num_tokens, hidden_dim, device=device)
labels = torch.randint(0, vocab_size, (num_tokens,), device=device)

with torch.no_grad():
    loss = fused_loss_fn(linear.weight, hidden_states, labels)
print(f"Loss: {loss.item()}")
```

- Hardware Type: 3090
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

> 2110 passed, 255 skipped, 41 warnings, 1 rerun in 276.17s (0:04:36)
Hi, is it possible to add SwiGLU support for Qwen3-VL? It seems that passing swiglu=True to the monkey patch function is currently a no-op.
@matthewdm0816 Feel free to open an issue for it! I'll work on it. |
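Until that lands, one possible interim workaround is to swap the text decoder MLPs for `LigerSwiGLUMLP` manually after loading the model. This is an untested sketch; the Qwen3-VL module path and MLP layout used below are assumptions, not taken from this PR:

```python
from liger_kernel.transformers import LigerSwiGLUMLP


def swap_text_mlps_to_liger_swiglu(model, text_config):
    # Replace each decoder layer's MLP with Liger's SwiGLU implementation,
    # copying the existing projection weights over. Assumes the Qwen3-VL text
    # stack exposes model.model.layers[*].mlp with gate_proj/up_proj/down_proj,
    # like other Qwen text models (an assumption to verify against the model).
    for layer in model.model.layers:
        ref_param = next(layer.mlp.parameters())
        liger_mlp = LigerSwiGLUMLP(text_config).to(ref_param.device, ref_param.dtype)
        liger_mlp.load_state_dict(layer.mlp.state_dict())
        layer.mlp = liger_mlp
```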
@matthewdm0816
The RoPE operator seems to have some issues. Trainer + Qwen3-VL-8B: `[rank6]: return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,`
## Summary
Support for Qwen3-VL models

Solves #897

## Details
NA

## Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence
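For reference, a hedged sketch of how the patch added in this PR would be enabled for training. The function name assumes the library's existing `apply_liger_kernel_to_<model>` convention, and the flag set mirrors the Qwen2-VL patch shown earlier in the thread; both are assumptions, not confirmed here:

```python
import torch
from transformers import AutoModelForImageTextToText

# Function name assumed from the library's apply_liger_kernel_to_<model> convention.
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl

# Monkey patches are typically applied before the model is instantiated so the
# patched Qwen3-VL classes pick up the Liger kernels.
apply_liger_kernel_to_qwen3_vl(
    fused_linear_cross_entropy=True,
    rms_norm=True,
    rope=True,
)

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",  # model used in the benchmarks above
    torch_dtype=torch.bfloat16,
)
```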