Commit 9f1222c

kdamaszk, Yu-Zhou, and jkaniecki authored
Fix OOM in fp8 w/ HPUGraph for llama3.2 (#1365) (#1479)
Signed-off-by: zhouyu5 <[email protected]>
Co-authored-by: Yu-Zhou <[email protected]>
Co-authored-by: Jan Kaniecki <[email protected]>
1 parent: 569c0aa · commit: 9f1222c

File tree: 2 files changed, +18 −6 lines


vllm/model_executor/models/mllama.py

Lines changed: 1 addition & 0 deletions
@@ -1350,6 +1350,7 @@ def forward(
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         skip_cross_attention: bool,
+        **kwargs: object,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids=input_ids,
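The new `**kwargs: object` parameter lets this forward absorb the extra keyword arguments that the runner change below starts binding onto `language_model.forward` via `functools.partial` (`attn_metadata`, `bypass_hpu_graphs`); without a catch-all, those keywords would raise a `TypeError`. A minimal, self-contained sketch of that interaction, using hypothetical stand-in functions rather than the real Mllama classes:

from functools import partial


# Hypothetical stand-ins for language_model.forward, for illustration only.
def forward_strict(input_ids, positions):
    return f"strict: {input_ids}, {positions}"


def forward_permissive(input_ids, positions, **kwargs: object):
    # Extra keywords such as bypass_hpu_graphs are simply absorbed.
    return f"permissive: {input_ids}, {positions}, extras={sorted(kwargs)}"


# The runner binds graph-related kwargs up front, as the partial() call in the
# second file below does for the real language model.
bound_strict = partial(forward_strict, bypass_hpu_graphs=False)
bound_permissive = partial(forward_permissive, bypass_hpu_graphs=False)

print(bound_permissive(input_ids=[1, 2], positions=[0, 1]))  # works
try:
    bound_strict(input_ids=[1, 2], positions=[0, 1])
except TypeError as exc:
    print(f"without **kwargs: {exc}")  # unexpected keyword argument

As far as this hunk shows, the extra keywords are only absorbed here, not forwarded into `self.model`.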

vllm/worker/hpu_enc_dec_model_runner.py

Lines changed: 17 additions & 6 deletions
@@ -3,6 +3,7 @@
 import gc
 import itertools
 import math
+from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast

 import habana_frameworks.torch as htorch
@@ -42,6 +43,13 @@ class HpuModelAdapterEncoderDecoder(HpuModelAdapter):
     def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
         super().__init__(model, vllm_config, layer_names, is_causal, sampler)

+        # We only wrap the language model in HPU graph because some Ops in
+        # vision model will fallback to CPU and cause the graph building fail.
+        if htorch.utils.internal.is_lazy() and hasattr(self.model,
+                                                       "language_model"):
+            self.model.language_model = htorch.hpu.wrap_in_hpu_graph(
+                self.model.language_model, disable_tensor_cache=True)
+
     def _set_cross_block_mapping(self, metadata, batch_size, device, dtype):
         mask = torch.arange(0,
                             self.block_size,
@@ -110,6 +118,13 @@ def forward(self, *args, **kwargs):
         kwargs['attn_metadata'] = self._update_cross_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
+        if htorch.utils.internal.is_lazy() and hasattr(self.model,
+                                                       "language_model"):
+            bypass_hpu_graphs = kwargs.get('bypass_hpu_graphs', False)
+            self.model.language_model.forward = partial(
+                self.model.language_model.forward,
+                attn_metadata=kwargs['attn_metadata'],
+                bypass_hpu_graphs=bypass_hpu_graphs)
         # TODO: Change the input_ids to 1D to match the public vllm
         # implementation and avoid shape mismatch issues with some
         # models(i.e. Mllama). But currently this will cause graph
@@ -118,9 +133,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = 0
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
+        attn_metadata = kwargs.pop('attn_metadata')
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
-        attn_metadata = kwargs.pop("attn_metadata")
         with set_forward_context(attn_metadata, self.vllm_config,
                                  virtual_engine):
             hidden_states = self.model(*args, **kwargs)
@@ -193,11 +208,7 @@ def _flatten(self, in_list):
         return list(itertools.chain(*in_list))

     def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
-        return htorch.hpu.wrap_in_hpu_graph(
-            HpuModelAdapterEncoderDecoder(*args, **kwargs),
-            disable_tensor_cache=True,
-        ) if htorch.utils.internal.is_lazy(
-        ) else HpuModelAdapterEncoderDecoder(*args, **kwargs)
+        return HpuModelAdapterEncoderDecoder(*args, **kwargs)

     def prepare_model_input(
         self,
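The net effect of the runner changes: HPU graph capture moves from the whole `HpuModelAdapterEncoderDecoder` (the branch removed from `_maybe_wrap_in_hpu_graph`) down to `self.model.language_model` alone, so ops in the vision model that fall back to CPU never take part in graph building. Below is a toy sketch of that submodule-only wrapping pattern; it assumes a Gaudi machine with `habana_frameworks.torch` available in lazy mode, and `ToyMultimodalModel` is a hypothetical stand-in, not the real Mllama stack:

import habana_frameworks.torch as htorch
import torch
import torch.nn as nn


class ToyMultimodalModel(nn.Module):
    """Hypothetical stand-in with a vision tower and a language model."""

    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(16, 32)    # imagine CPU-fallback ops here
        self.language_model = nn.Linear(32, 32)  # HPU-friendly decoder

    def forward(self, pixels):
        return self.language_model(self.vision_model(pixels))


model = ToyMultimodalModel().to("hpu")

# Capture only the HPU-friendly submodule, mirroring the adapter's __init__;
# the vision side stays eager, so its CPU fallbacks cannot break graph capture.
if htorch.utils.internal.is_lazy() and hasattr(model, "language_model"):
    model.language_model = htorch.hpu.wrap_in_hpu_graph(
        model.language_model, disable_tensor_cache=True)

out = model(torch.randn(4, 16, device="hpu"))
print(out.shape)  # torch.Size([4, 32])

Because the outer adapter is no longer the graph-captured callable, the `functools.partial` rebinding in `forward` (shown in the diff above) is what pins the per-step `attn_metadata` and `bypass_hpu_graphs` onto the wrapped submodule's call.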
