 import gc
 import itertools
 import math
+from functools import partial

 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast

 import habana_frameworks.torch as htorch
@@ -42,6 +43,13 @@ class HpuModelAdapterEncoderDecoder(HpuModelAdapter):
     def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
         super().__init__(model, vllm_config, layer_names, is_causal, sampler)

+        # We only wrap the language model in an HPU graph because some ops in
+        # the vision model fall back to CPU and cause graph building to fail.
+        if htorch.utils.internal.is_lazy() and hasattr(self.model,
+                                                       "language_model"):
+            self.model.language_model = htorch.hpu.wrap_in_hpu_graph(
+                self.model.language_model, disable_tensor_cache=True)
+
     def _set_cross_block_mapping(self, metadata, batch_size, device, dtype):
         mask = torch.arange(0,
                             self.block_size,
@@ -110,6 +118,13 @@ def forward(self, *args, **kwargs):
         kwargs['attn_metadata'] = self._update_cross_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
+        if htorch.utils.internal.is_lazy() and hasattr(self.model,
+                                                       "language_model"):
+            bypass_hpu_graphs = kwargs.get('bypass_hpu_graphs', False)
+            self.model.language_model.forward = partial(
+                self.model.language_model.forward,
+                attn_metadata=kwargs['attn_metadata'],
+                bypass_hpu_graphs=bypass_hpu_graphs)
         # TODO: Change the input_ids to 1D to match the public vllm
         # implementation and avoid shape mismatch issues with some
         # models (i.e. Mllama). But currently this will cause graph
@@ -118,9 +133,9 @@ def forward(self, *args, **kwargs):
         virtual_engine = 0
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
+        attn_metadata = kwargs.pop('attn_metadata')
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
-        attn_metadata = kwargs.pop("attn_metadata")
         with set_forward_context(attn_metadata, self.vllm_config,
                                  virtual_engine):
             hidden_states = self.model(*args, **kwargs)
@@ -193,11 +208,7 @@ def _flatten(self, in_list):
         return list(itertools.chain(*in_list))

     def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
-        return htorch.hpu.wrap_in_hpu_graph(
-            HpuModelAdapterEncoderDecoder(*args, **kwargs),
-            disable_tensor_cache=True,
-        ) if htorch.utils.internal.is_lazy(
-        ) else HpuModelAdapterEncoderDecoder(*args, **kwargs)
+        return HpuModelAdapterEncoderDecoder(*args, **kwargs)

     def prepare_model_input(
         self,
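
For context, below is a minimal, self-contained sketch (plain PyTorch, no Habana dependency) of the pattern this diff introduces: only the language_model submodule gets graph-wrapped, and per-step keyword arguments are pre-bound to its forward with functools.partial before the adapter delegates to the outer model. ToyLanguageModel, ToyEncoderDecoder, and the placeholder metadata are illustrative names, not part of the PR; the real adapter wraps the submodule with htorch.hpu.wrap_in_hpu_graph instead of the no-op shown here.

from functools import partial

import torch
import torch.nn as nn


class ToyLanguageModel(nn.Module):
    """Illustrative stand-in for the language_model submodule."""

    def forward(self, input_ids, attn_metadata=None, bypass_hpu_graphs=False):
        # The pre-bound kwargs arrive here without the caller passing them.
        assert attn_metadata is not None
        return input_ids.float().sum()


class ToyEncoderDecoder(nn.Module):
    """Illustrative stand-in for the multimodal encoder-decoder model."""

    def __init__(self):
        super().__init__()
        self.language_model = ToyLanguageModel()

    def forward(self, input_ids, **kwargs):
        return self.language_model(input_ids, **kwargs)


model = ToyEncoderDecoder()

# In the PR, model.language_model is wrapped with
# htorch.hpu.wrap_in_hpu_graph(..., disable_tensor_cache=True) when lazy mode
# is active; that step is skipped here so the sketch runs without Habana.

# Pre-bind the per-step kwargs, mirroring what the adapter's forward() does
# before calling self.model(*args, **kwargs).
model.language_model.forward = partial(
    model.language_model.forward,
    attn_metadata={"step": 0},  # placeholder value, hypothetical
    bypass_hpu_graphs=False)

print(model(torch.ones(2, 4, dtype=torch.long)))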