Commit 5c89882

fix patch
Signed-off-by: wangxiyuan <[email protected]>
1 parent 353ecf6

File tree

1 file changed: +103 -0 lines changed


vllm_ascend/patch/worker/patch_common/patch_weight_loader.py

Lines changed: 103 additions & 0 deletions
@@ -1,5 +1,9 @@
+from typing import Callable, Optional
+
 import torch
 from torch.nn.parameter import Parameter
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import GiB_bytes
@@ -39,6 +43,105 @@ def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int,
         set_weight_attrs(weight, extra_weight_attrs)
 
 
+class CustomBasevLLMParameter(Parameter):
+    """
+    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
+    by taking in a linear weight loader. Will copy the loaded weight
+    into the parameter when the provided weight loader is called.
+    """
+
+    def __new__(cls, data: Optional[torch.Tensor], **kwargs):
+
+        return super().__new__(cls, data=data, requires_grad=False)
+
+    def __init__(self, data: torch.Tensor, weight_loader: Callable):
+        """
+        Initialize the BasevLLMParameter
+
+        :param data: torch tensor with the parameter data
+        :param weight_loader: weight loader callable
+
+        :returns: a torch.nn.parameter
+        """
+
+        # During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to the redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
+        from vllm.platforms import current_platform
+        if current_platform.use_sync_weight_loader():
+            weight_loader = current_platform.make_synced_weight_loader(
+                weight_loader)
+
+        self._weight_loader = weight_loader
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+    @property
+    def weight_loader(self):
+        # NOTE(@ksayers) some models such as mamba_mixer2 override the
+        # weight loader to support custom loading. In the future, model-specific
+        # weight loading should be implemented via Model.load_weights. In the
+        # meantime, support deleting and overriding `weight_loader`` attribute
+        if self._weight_loader is None:
+            raise AttributeError(f"{self.__class__.__name__} weight_loader "
+                                 "attribute has been deleted")
+        return self._weight_loader
+
+    @weight_loader.setter
+    def weight_loader(self, value):
+        self._weight_loader = value
+
+    @weight_loader.deleter
+    def weight_loader(self):
+        self._weight_loader = None  # type: ignore[assignment]
+
+    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
+        cond1 = self.data.ndim == 1 and self.data.numel() == 1
+        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
+        return (cond1 and cond2)
+
+    def _assert_and_load(self, loaded_weight: torch.Tensor):
+        assert (self.data.shape == loaded_weight.shape
+                or self._is_1d_and_scalar(loaded_weight))
+        self.data.copy_(loaded_weight)
+
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+        self._assert_and_load(loaded_weight)
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+        self._assert_and_load(loaded_weight)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        self._assert_and_load(loaded_weight)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        self._assert_and_load(loaded_weight)
+
+    def _shard_id_as_int(self, shard_id) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        # if not int, assume shard_id for qkv
+        # map to int and return
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert isinstance(shard_id, str)
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return super().__torch_function__(func, types, args, kwargs)
+
+
 if not vllm_version_is("0.10.2"):
     from vllm.model_executor.layers.linear import UnquantizedLinearMethod
     UnquantizedLinearMethod.create_weights = create_weights
+    import vllm
+    vllm.model_executor.parameter.BasevLLMParameter = CustomBasevLLMParameter
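For orientation, here is a minimal sketch (not part of the commit) of how the monkey patch is expected to take effect: importing the patch module runs its module-level code, which rebinds vLLM's BasevLLMParameter symbol to CustomBasevLLMParameter, so parameters constructed afterwards use the patched weight-loading behaviour. The dummy_loader helper and the tensor shapes below are illustrative assumptions, and constructing the parameter requires an already-initialized tensor-parallel group (its __init__ calls get_tensor_model_parallel_rank/world_size).

# Illustrative sketch only; assumes vLLM (newer than 0.10.2) and vllm-ascend
# are installed. On 0.10.2 the rebinding above is skipped entirely.
import torch
import vllm

# Importing the patch module executes its module-level code, which replaces
# vllm.model_executor.parameter.BasevLLMParameter with CustomBasevLLMParameter.
from vllm_ascend.patch.worker.patch_common import patch_weight_loader  # noqa: F401

print(vllm.model_executor.parameter.BasevLLMParameter.__name__)
# Expected on patched versions: "CustomBasevLLMParameter"

# dummy_loader is a hypothetical loader used only for illustration; real
# loaders are supplied by the linear layers that own the parameter.
def dummy_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
    param.data.copy_(loaded_weight)

# Requires an initialized tensor-parallel group, e.g. inside a worker process.
param = vllm.model_executor.parameter.BasevLLMParameter(
    data=torch.zeros(8), weight_loader=dummy_loader)
param.load_row_parallel_weight(torch.ones(8))  # shapes must match (_assert_and_load)
del param.weight_loader              # deleter: the stored loader becomes None
param.weight_loader = dummy_loader   # setter: restore; the getter raises while it is None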
