From edbe3bfd3bd1a6b3c9e16b58ac60b3f4454d79aa Mon Sep 17 00:00:00 2001 From: l00591931 Date: Wed, 3 Feb 2021 11:10:20 +0800 Subject: [PATCH] Add deprecated function --- mindspore/common/_decorator.py | 5 ++++- mindspore/ops/operations/array_ops.py | 28 +++++++++++++++++++++++---- mindspore/ops/operations/math_ops.py | 28 ++++++++++++++++++++++----- mindspore/ops/primitive.py | 7 +++++-- 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/mindspore/common/_decorator.py b/mindspore/common/_decorator.py index 892d76548f7..7f75fd17ae2 100644 --- a/mindspore/common/_decorator.py +++ b/mindspore/common/_decorator.py @@ -15,12 +15,13 @@ """Providing decorators.""" -def deprecated(version, substitute): +def deprecated(version, substitute, use_substitute_name=False): """deprecated warning Args: version (str): version that the operator or function is deprecated. substitute (str): the substitute name for deprecated operator or function. + use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function """ def decorate(func): @@ -29,6 +30,8 @@ def deprecated(version, substitute): name = cls.__name__ if cls else func.__name__ print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " f"use '{substitute}' instead.") + if cls and use_substitute_name: + cls.substitute_name = substitute ret = func(*args, **kwargs) return ret diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 99444dd13ed..a792910808e 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -33,6 +33,7 @@ from .. import signature as sig from ..._checkparam import Rel from ..._checkparam import Validator as validator from ...common import dtype as mstype +from ...common._decorator import deprecated from ...common.parameter import Parameter from ...common.tensor import Tensor @@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck): validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) -def GatherV2(): - """Warning: This will be changed later""" - logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") - return Gather() +class GatherV2(PrimitiveWithCheck): + """ + Same as operator Gather. GatherV2 will be deprecated in the future. + Please use Gather instead. + """ + #deprecate_new_name = "Gather" + + @deprecated("1.1", "Gather", True) + @prim_attr_register + def __init__(self): + """Initialize index_select""" + self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) + self.add_prim_attr("dynamic_shape_depends", [2]) + + def __check__(self, params, indices, axis): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) + validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name) + axis_v = axis['value'] + validator.check_value_type('axis', axis_v, [int], self.name) + rank = len(params['shape']) + validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) + class SparseGatherV2(Gather): """ diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index fc5714e39c8..2737082655d 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -18,13 +18,13 @@ import copy import numpy as np -from mindspore import log as logger from ... import context from .. import signature as sig from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor +from ...common._decorator import deprecated from .._utils import get_broadcast_shape from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op @@ -161,10 +161,28 @@ class Add(_MathBinaryOp): return Tensor(out) return None -def TensorAdd(): - """Warning: This will be changed later""" - logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") - return Add() + +class TensorAdd(_MathBinaryOp): + """ + Same as operator Add. TensorAdd will be deprecated in the future. + Please use Add instead. + """ + #deprecate_new_name = "Add" + + @deprecated("1.1", "Add", True) + @prim_attr_register + def __init__(self): + _MathBinaryOp.__init__(self) + + def infer_value(self, x, y): + if x is not None and y is not None: + x = x.asnumpy() + y = y.asnumpy() + out = x + y + out = np.array(out, x.dtype) + return Tensor(out) + return None + class AssignAdd(PrimitiveWithInfer): """ diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index f0684bae4e7..ade3d1d09cb 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -466,10 +466,13 @@ def prim_attr_register(fn): """ def deco(self, *args, **kwargs): + class_name = self.__class__.__name__ + if hasattr(self.__class__, "substitute_name"): + class_name = self.__class__.substitute_name if isinstance(self, PrimitiveWithInfer): - PrimitiveWithInfer.__init__(self, self.__class__.__name__) + PrimitiveWithInfer.__init__(self, class_name) elif isinstance(self, PrimitiveWithCheck): - PrimitiveWithCheck.__init__(self, self.__class__.__name__) + PrimitiveWithCheck.__init__(self, class_name) else: Primitive.__init__(self, self.__class__.__name__) bound_args = inspect.signature(fn).bind(self, *args, **kwargs)