forked from mindspore-Ecosystem/mindspore
!12033 Change deprecated function in master
From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
0b96a36d0c
|
@ -15,12 +15,13 @@
|
|||
"""Providing decorators."""
|
||||
|
||||
|
||||
def deprecated(version, substitute):
|
||||
def deprecated(version, substitute, use_substitute_name=False):
|
||||
"""deprecated warning
|
||||
|
||||
Args:
|
||||
version (str): version that the operator or function is deprecated.
|
||||
substitute (str): the substitute name for deprecated operator or function.
|
||||
use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function
|
||||
"""
|
||||
|
||||
def decorate(func):
|
||||
|
@ -29,6 +30,8 @@ def deprecated(version, substitute):
|
|||
name = cls.__name__ if cls else func.__name__
|
||||
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, "
|
||||
f"use '{substitute}' instead.")
|
||||
if cls and use_substitute_name:
|
||||
cls.substitute_name = substitute
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ from .. import signature as sig
|
|||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ...common._decorator import deprecated
|
||||
from ...common.parameter import Parameter
|
||||
from ...common.tensor import Tensor
|
||||
|
||||
|
@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck):
|
|||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
|
||||
|
||||
def GatherV2():
|
||||
"""Warning: This will be changed later"""
|
||||
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.")
|
||||
return Gather()
|
||||
class GatherV2(PrimitiveWithCheck):
|
||||
"""
|
||||
Same as operator Gather. GatherV2 will be deprecated in the future.
|
||||
Please use Gather instead.
|
||||
"""
|
||||
#deprecate_new_name = "Gather"
|
||||
|
||||
@deprecated("1.1", "Gather", True)
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize index_select"""
|
||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||
self.add_prim_attr("dynamic_shape_depends", [2])
|
||||
|
||||
def __check__(self, params, indices, axis):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
|
||||
axis_v = axis['value']
|
||||
validator.check_value_type('axis', axis_v, [int], self.name)
|
||||
rank = len(params['shape'])
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
|
||||
|
||||
class SparseGatherV2(Gather):
|
||||
"""
|
||||
|
|
|
@ -18,13 +18,13 @@
|
|||
import copy
|
||||
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from ... import context
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from ...common._decorator import deprecated
|
||||
from .._utils import get_broadcast_shape
|
||||
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||
|
||||
|
@ -161,10 +161,28 @@ class Add(_MathBinaryOp):
|
|||
return Tensor(out)
|
||||
return None
|
||||
|
||||
def TensorAdd():
|
||||
"""Warning: This will be changed later"""
|
||||
logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.")
|
||||
return Add()
|
||||
|
||||
class TensorAdd(_MathBinaryOp):
|
||||
"""
|
||||
Same as operator Add. TensorAdd will be deprecated in the future.
|
||||
Please use Add instead.
|
||||
"""
|
||||
#deprecate_new_name = "Add"
|
||||
|
||||
@deprecated("1.1", "Add", True)
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
_MathBinaryOp.__init__(self)
|
||||
|
||||
def infer_value(self, x, y):
|
||||
if x is not None and y is not None:
|
||||
x = x.asnumpy()
|
||||
y = y.asnumpy()
|
||||
out = x + y
|
||||
out = np.array(out, x.dtype)
|
||||
return Tensor(out)
|
||||
return None
|
||||
|
||||
|
||||
class AssignAdd(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -466,10 +466,13 @@ def prim_attr_register(fn):
|
|||
"""
|
||||
|
||||
def deco(self, *args, **kwargs):
|
||||
class_name = self.__class__.__name__
|
||||
if hasattr(self.__class__, "substitute_name"):
|
||||
class_name = self.__class__.substitute_name
|
||||
if isinstance(self, PrimitiveWithInfer):
|
||||
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
|
||||
PrimitiveWithInfer.__init__(self, class_name)
|
||||
elif isinstance(self, PrimitiveWithCheck):
|
||||
PrimitiveWithCheck.__init__(self, self.__class__.__name__)
|
||||
PrimitiveWithCheck.__init__(self, class_name)
|
||||
else:
|
||||
Primitive.__init__(self, self.__class__.__name__)
|
||||
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue