patch/worker/patch_common: 2 files changed, +1 line added, -5 lines removed
@@ -1,4 +1,3 @@
-
 import torch
 from torch.nn.parameter import Parameter
 from vllm.logger import init_logger
@@ -33,7 +33,6 @@
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     UnquantizedEmbeddingMethod, VocabParallelEmbedding)
-from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
@@ -251,7 +250,6 @@ def create_weights(
         **extra_weight_attrs,
     ) -> None:
         output_size_per_partition = sum(output_partition_sizes)
-        weight_loader = extra_weight_attrs.get("weight_loader")

         weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                    output_size_per_partition,
@@ -264,8 +262,7 @@ def create_weights(

         pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
         for pertensor_name, pertensor_param in pertensor_dict.items():
-            param = PerTensorScaleParameter(data=pertensor_param,
-                                            weight_loader=weight_loader)
+            param = torch.nn.Parameter(pertensor_param, requires_grad=False)
             # disable warning
             param.ignore_warning = True
             layer.register_parameter(pertensor_name, param)
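
For context, a minimal sketch of what the new registration path amounts to: each per-tensor scale is wrapped in a plain torch.nn.Parameter with requires_grad=False and registered on the layer directly, instead of going through vLLM's PerTensorScaleParameter and its weight_loader hook. The toy layer and the contents of pertensor_dict below are made-up stand-ins, not the real vllm-ascend quant method output.

import torch

# Stand-in for the quantized layer handed to create_weights() (assumption).
layer = torch.nn.Linear(16, 16)

# Stand-in for self.quant_method.get_pertensor_param(params_dtype):
# one scale tensor per parameter name (values here are placeholders).
pertensor_dict = {
    "input_scale": torch.ones(1, dtype=torch.float32),
    "weight_scale": torch.ones(1, dtype=torch.float32),
}

for pertensor_name, pertensor_param in pertensor_dict.items():
    # New behaviour from the diff: a frozen plain Parameter, no weight_loader.
    param = torch.nn.Parameter(pertensor_param, requires_grad=False)
    # Mirrors the diff's "# disable warning" flag attribute, which is read
    # elsewhere in vLLM's loading code.
    param.ignore_warning = True
    layer.register_parameter(pertensor_name, param)

print([name for name, _ in layer.named_parameters()])
# ['weight', 'bias', 'input_scale', 'weight_scale']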