diff --git a/mindnlp/core/nn/modules/module.py b/mindnlp/core/nn/modules/module.py index c5efd9ac8..987c628b3 100644 --- a/mindnlp/core/nn/modules/module.py +++ b/mindnlp/core/nn/modules/module.py @@ -772,7 +772,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, setattr(self, name, input_param) else: param.data_sync(True) - dtype = param.dtype + dtype = input_param.dtype param.assign_value(input_param) param.set_dtype(dtype) except Exception as ex: