Commit c4835ab

Fix LoRALinear throwing errors (#2909)
Signed-off-by: Emmanuel Ferdman <[email protected]>
Parent: 97bd210

2 files changed: +10, -3 lines

tests/torchtune/modules/peft/test_lora.py

Lines changed: 8 additions & 1 deletion
@@ -237,9 +237,16 @@ def test_quantized_state_dict(self, dtype):
         )
 
     def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None:
-        lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
+        lora_linear = lora_linear(use_bias=False, dtype=torch.float32)
         qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear)
         expected = torch.tensor(QAT_EXPECTED_VAL)
         actual = qat_lora_linear(inputs)
         assert actual.shape == (BSZ, SEQ_LEN, out_dim)
         torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)
+
+    def test_qat_lora_with_bias_raises_error(self, lora_linear) -> None:
+        lora_linear_with_bias = lora_linear(use_bias=True, dtype=torch.float32)
+        with pytest.raises(
+            ValueError, match="Bias is not supported in QAT \\+ LoRA yet"
+        ):
+            QATLoRALinear.from_lora_linear(lora_linear_with_bias)
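
A side note on the escaped "\\+" in the new test: pytest applies the match argument as a regular expression (via re.search) against the string of the raised exception, so the literal plus sign in "QAT + LoRA" has to be escaped. A minimal standalone sketch of the same idea, independent of the repo (the test name here is illustrative):

import re

import pytest


def test_match_is_a_regex():
    # pytest searches `match` as a regex against str(exception);
    # re.escape handles the literal "+" instead of escaping by hand.
    with pytest.raises(ValueError, match=re.escape("QAT + LoRA")):
        raise ValueError("Bias is not supported in QAT + LoRA yet")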

torchtune/modules/peft/lora.py

Lines changed: 2 additions & 2 deletions
@@ -277,9 +277,9 @@ def from_lora_linear(
         preserving the weights and adapters.
         """
         if lora_linear.bias is not None:
-            ValueError("Bias is not supported in QAT + LoRA yet")
+            raise ValueError("Bias is not supported in QAT + LoRA yet")
         if lora_linear._quantize_base:
-            ValueError("quantize_base is not compatible with QAT + LoRA")
+            raise ValueError("quantize_base is not compatible with QAT + LoRA")
         if isinstance(lora_linear.dropout, nn.Dropout):
             dropout = lora_linear.dropout.p
         else:
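
Why this change matters: each ValueError was constructed but never raised, so from_lora_linear silently accepted bias and quantize_base configurations it does not support. A minimal sketch of the bug pattern, not taken from the repo (function names are hypothetical):

def check_broken(bias):
    if bias is not None:
        ValueError("bias not supported")  # builds the exception object,
        # then immediately discards it: a silent no-op


def check_fixed(bias):
    if bias is not None:
        raise ValueError("bias not supported")


check_broken(bias=0.1)  # returns normally despite the "error"
try:
    check_fixed(bias=0.1)
except ValueError as err:
    print(err)  # prints: bias not supported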
