From d8382f4d57d6e0469d5f8c4eb0c8c05d584c9b5d Mon Sep 17 00:00:00 2001
From: hedongdong
Date: Thu, 4 Mar 2021 16:38:07 +0800
Subject: [PATCH] move operator primitive of Centralization to _inner_ops

---
 mindspore/ops/operations/__init__.py       |  2 +-
 mindspore/ops/operations/_inner_ops.py     | 69 ++++++++++++++++++-
 mindspore/ops/operations/inner_ops.py      | 68 ------------------
 .../test_tbe_ops/test_centralization.py    |  4 +-
 4 files changed, 71 insertions(+), 72 deletions(-)

diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 5baf2d138b8..91a3a4e7697 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, TensorSummary,
                         HistogramSummary, Print, Assert)
 from .control_ops import ControlDepend, GeSwitch, Merge
-from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, Centralization
+from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        BitwiseAnd, BitwiseOr,
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index a73c821ee0b..4e663e4d061 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -22,7 +22,7 @@ from ...common import dtype as mstype
 from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
 from ..operations.math_ops import _infer_shape_reduce
 from ...communication.management import GlobalComm
-
+from .. import signature as sig
 
 class ExtractImagePatches(PrimitiveWithInfer):
     """
@@ -815,3 +815,70 @@ class SyncBatchNorm(PrimitiveWithInfer):
         args_moving = {"mean": mean, "variance": variance}
         validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
         return (input_x, scale, bias, input_x, input_x)
+
+
+class Centralization(PrimitiveWithInfer):
+    """
+    Computes centralization. y = x - mean(x, axis).
+
+    Note:
+        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
+        - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).
+
+    Outputs:
+        Tensor, has the same shape and dtype as the `input_x`.
+
+    Raises:
+        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
+        TypeError: If `axis` has non-Int elements.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> mindspore.set_seed(1)
+        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+        >>> centralization = ops.Centralization()
+        >>> output = centralization(input_x, -1)
+        >>> print(output)
+        [[ 1.1180509 -1.1180508]
+         [ 0.2723984 -0.2723984]]
+    """
+
+    __mindspore_signature__ = (
+        sig.make_sig('input_x'),
+        sig.make_sig('axis', default=())
+    )
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Centralization"""
+        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])
+
+    def __infer__(self, input_x, axis):
+        x_shape = list(input_x['shape'])
+        x_dtype = input_x['dtype']
+        axis_v = axis['value']
+        rank = len(x_shape)
+
+        args = {'input_x': input_x['dtype']}
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
+
+        if axis_v is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
+        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)
+
+        if isinstance(axis_v, int):
+            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
+        elif axis_v:
+            for index, one_axis in enumerate(axis_v):
+                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)
+
+        out = {'shape': x_shape,
+               'dtype': x_dtype,
+               'value': None}
+        return out
diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index a90bfe88a11..efd5ef3ff1b 100644
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -21,7 +21,6 @@ from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.dtype import tensor, dtype_to_pytype
 from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
-from .. import signature as sig
 
 
 class ScalarCast(PrimitiveWithInfer):
@@ -358,70 +357,3 @@ class MakeRefKey(Primitive):
 
     def __call__(self):
         pass
-
-
-class Centralization(PrimitiveWithInfer):
-    """
-    Computes centralization. y = x - mean(x, axis).
-
-    Note:
-        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.
-
-    Inputs:
-        - **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32.
-        - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
-          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).
-
-    Outputs:
-        Tensor, has the same shape and dtype as the `input_x`.
-
-    Raises:
-        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
-        TypeError: If `axis` has non-Int elements.
-
-    Supported Platforms:
-        ``Ascend``
-
-    Examples:
-        >>> mindspore.set_seed(1)
-        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
-        >>> centralization = ops.Centralization()
-        >>> output = centralization(input_x, -1)
-        >>> print(output)
-        [[ 1.1180509 -1.1180508]
-         [ 0.2723984 -0.2723984]]
-    """
-
-    __mindspore_signature__ = (
-        sig.make_sig('input_x'),
-        sig.make_sig('axis', default=())
-    )
-
-    @prim_attr_register
-    def __init__(self):
-        """Initialize Centralization"""
-        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])
-
-    def __infer__(self, input_x, axis):
-        x_shape = list(input_x['shape'])
-        x_dtype = input_x['dtype']
-        axis_v = axis['value']
-        rank = len(x_shape)
-
-        args = {'input_x': input_x['dtype']}
-        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
-
-        if axis_v is None:
-            raise ValueError(f"For {self.name}, axis must be const.")
-        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)
-
-        if isinstance(axis_v, int):
-            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
-        elif axis:
-            for index, one_axis in enumerate(axis_v):
-                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)
-
-        out = {'shape': x_shape,
-               'dtype': x_dtype,
-               'value': None}
-        return out
diff --git a/tests/st/ops/ascend/test_tbe_ops/test_centralization.py b/tests/st/ops/ascend/test_tbe_ops/test_centralization.py
index 9012fa2910d..26854ebfb3f 100644
--- a/tests/st/ops/ascend/test_tbe_ops/test_centralization.py
+++ b/tests/st/ops/ascend/test_tbe_ops/test_centralization.py
@@ -18,12 +18,12 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common.api import ms_function
-from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
 
 class Net(nn.Cell):
     def __init__(self, axis=()):
         super(Net, self).__init__()
-        self.centralization = P.Centralization()
+        self.centralization = inner.Centralization()
         self.axis = axis
 
     @ms_function
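
Reviewer note, not part of the patch: after this move, Centralization is no
longer exported from mindspore.ops.operations, so the docstring's
`ops.Centralization()` example will not resolve; callers must import the
primitive from the private _inner_ops module, as the updated test does. The
sketch below is a minimal usage example under that assumption (a MindSpore 1.x
build with an Ascend device); the CentralizationNet name and the NumPy
cross-check are illustrative only, not part of this change.

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops.operations import _inner_ops as inner

    class CentralizationNet(nn.Cell):
        """Wraps the internal Centralization primitive: y = x - mean(x, axis)."""
        def __init__(self, axis=()):
            super(CentralizationNet, self).__init__()
            self.centralization = inner.Centralization()
            self.axis = axis

        def construct(self, input_x):
            return self.centralization(input_x, self.axis)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    x = np.random.randn(2, 4).astype(np.float32)
    output = CentralizationNet(axis=-1)(Tensor(x))
    # Centralizing along the last axis subtracts each row's mean, so the
    # device result should closely match the host-side NumPy reference.
    expected = x - x.mean(axis=-1, keepdims=True)
    assert np.allclose(output.asnumpy(), expected, atol=1e-3)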