Change the Parameter type name from 'ParameterTensor' to 'Parameter'

This commit is contained in:
Zhang Qinghua 2022-03-08 09:27:21 +08:00
parent 8220cd601a
commit 973008ebcf
3 changed files with 3 additions and 3 deletions

View File

@ -216,7 +216,7 @@ class Parameter(Tensor_):
@staticmethod
def _get_base_class(input_class):
input_class_name = f'Parameter{input_class.__name__}'
input_class_name = Parameter.__name__
if input_class_name in Parameter.__base_type__:
new_type = Parameter.__base_type__[input_class_name]
else:

View File

@ -101,7 +101,7 @@ def test_outermost_net_pass_parameter():
"support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
"but the 1th arg type is <class 'mindspore.common.parameter.Parameter'>, " \
"value is 'Parameter (name=weight, shape=(2, 2), dtype=Float32, requires_grad=True)'" \
in str(err.value)

View File

@ -103,7 +103,7 @@ def test_outermost_net_pass_parameter():
"support bool, int, float, None, tensor, " \
"mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint), " \
"and tuple or list containing only these types, and dict whose values are these types, " \
"but the 1th arg type is <class 'mindspore.common.parameter.ParameterTensor'>, " \
"but the 1th arg type is <class 'mindspore.common.parameter.Parameter'>, " \
"value is 'Parameter (name=weight, shape=(2, 2), dtype=Float32, requires_grad=True)'" \
in str(err.value)