!12033 Change deprecated function in master

From: @liangzhibo
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-02-03 20:31:45 +08:00 committed by Gitee
commit 0b96a36d0c
4 changed files with 56 additions and 12 deletions

View File

@ -15,12 +15,13 @@
"""Providing decorators."""
def deprecated(version, substitute):
def deprecated(version, substitute, use_substitute_name=False):
"""deprecated warning
Args:
version (str): version that the operator or function is deprecated.
substitute (str): the substitute name for deprecated operator or function.
use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function
"""
def decorate(func):
@ -29,6 +30,8 @@ def deprecated(version, substitute):
name = cls.__name__ if cls else func.__name__
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, "
f"use '{substitute}' instead.")
if cls and use_substitute_name:
cls.substitute_name = substitute
ret = func(*args, **kwargs)
return ret

View File

@ -33,6 +33,7 @@ from .. import signature as sig
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ...common._decorator import deprecated
from ...common.parameter import Parameter
from ...common.tensor import Tensor
@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck):
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
def GatherV2():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.")
return Gather()
class GatherV2(PrimitiveWithCheck):
"""
Same as operator Gather. GatherV2 will be deprecated in the future.
Please use Gather instead.
"""
#deprecate_new_name = "Gather"
@deprecated("1.1", "Gather", True)
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
class SparseGatherV2(Gather):
"""

View File

@ -18,13 +18,13 @@
import copy
import numpy as np
from mindspore import log as logger
from ... import context
from .. import signature as sig
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ...common._decorator import deprecated
from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@ -161,10 +161,28 @@ class Add(_MathBinaryOp):
return Tensor(out)
return None
def TensorAdd():
"""Warning: This will be changed later"""
logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.")
return Add()
class TensorAdd(_MathBinaryOp):
"""
Same as operator Add. TensorAdd will be deprecated in the future.
Please use Add instead.
"""
#deprecate_new_name = "Add"
@deprecated("1.1", "Add", True)
@prim_attr_register
def __init__(self):
_MathBinaryOp.__init__(self)
def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()
y = y.asnumpy()
out = x + y
out = np.array(out, x.dtype)
return Tensor(out)
return None
class AssignAdd(PrimitiveWithInfer):
"""

View File

@ -466,10 +466,13 @@ def prim_attr_register(fn):
"""
def deco(self, *args, **kwargs):
class_name = self.__class__.__name__
if hasattr(self.__class__, "substitute_name"):
class_name = self.__class__.substitute_name
if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
PrimitiveWithInfer.__init__(self, class_name)
elif isinstance(self, PrimitiveWithCheck):
PrimitiveWithCheck.__init__(self, self.__class__.__name__)
PrimitiveWithCheck.__init__(self, class_name)
else:
Primitive.__init__(self, self.__class__.__name__)
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)