From 374e9e199d4573484b414e4ca4d744a49880e97c Mon Sep 17 00:00:00 2001
From: jzg
Date: Fri, 16 Oct 2020 14:39:35 +0800
Subject: [PATCH] add nn.Moments and count_nonzero.

---
 mindspore/nn/layer/math.py             | 65 +++++++++++++++++++++-
 mindspore/ops/composite/__init__.py    |  4 +-
 mindspore/ops/composite/math_ops.py    | 75 ++++++++++++++++++++++++++
 mindspore/ops/operations/_quant_ops.py | 12 ++---
 mindspore/ops/operations/array_ops.py  |  1 +
 tests/ut/python/ops/test_ops.py        | 39 ++++++++++++--
 6 files changed, 185 insertions(+), 11 deletions(-)
 create mode 100644 mindspore/ops/composite/math_ops.py

diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index d5f5d31b7c2..2ed051122c7 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -19,12 +19,13 @@ from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.tensor import Tensor
 from mindspore.ops.primitive import constexpr
+from mindspore.ops import functional as F
 from ..cell import Cell
 from ...common import dtype as mstype
 from ..._checkparam import Validator as validator
 
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments']
 
 
 class ReduceLogSumExp(Cell):
@@ -451,3 +452,65 @@ class MatMul(Cell):
             matmul_broadcast = self.squeeze_right_op(matmul_broadcast)
 
         return matmul_broadcast
+
+
+@constexpr
+def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
+    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
+
+
+class Moments(Cell):
+    """
+    Calculates the mean and variance of `x` along the specified axis.
+
+    Args:
+        axis (Union[int, tuple(int)]): The axis along which the mean and variance are calculated. Default: ().
+        keep_dims (bool): If true, the dimensions of the mean and variance are identical with the input's.
+            If false, don't keep these dimensions. Default: False.
+
+    Inputs:
+        - **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
+
+    Outputs:
+        - **mean** (Tensor) - The mean of input x, with the same data type as input x.
+        - **variance** (Tensor) - The variance of input x, with the same data type as input x.
+
+    Examples:
+        >>> net = nn.Moments(axis=3, keep_dims=True)
+        >>> input_x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
+        >>> mean, var = net(input_x)
+        mean: [[[[2.5], [4.5]]]]
+        var: [[[[1.25], [1.25]]]]
+    """

+    def __init__(self, axis=None, keep_dims=None):
+        super(Moments, self).__init__()
+        if axis is None:
+            axis = ()
+        if isinstance(axis, tuple):
+            for idx, item in enumerate(axis):
+                validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
+        self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
+        if keep_dims is None:
+            keep_dims = False
+        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
+        self.cast = P.Cast()
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
+        self.square_diff = P.SquaredDifference()
+        self.squeeze = P.Squeeze(self.axis)
+
+    def construct(self, x):
+        tensor_dtype = x.dtype
+        _check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
+        if tensor_dtype == mstype.float16:
+            x = self.cast(x, mstype.float32)
+        mean = self.reduce_mean(x, self.axis)
+        variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis)
+        if not self.keep_dims:
+            mean = self.squeeze(mean)
+            variance = self.squeeze(variance)
+        if tensor_dtype == mstype.float16:
+            mean = self.cast(mean, mstype.float16)
+            variance = self.cast(variance, mstype.float16)
+            return mean, variance
+        return mean, variance
diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py
index a8bdd67c220..3b97058df43 100644
--- a/mindspore/ops/composite/__init__.py
+++ b/mindspore/ops/composite/__init__.py
@@ -27,6 +27,7 @@ from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
 from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
+from .math_ops import count_nonzero
 
 
 __all__ = [
@@ -47,4 +48,5 @@ __all__ = [
     'gamma',
     'poisson',
     'multinomial',
-    'clip_by_value',]
+    'clip_by_value',
+    'count_nonzero']
diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py
new file mode 100644
index 00000000000..03b8ef6f27a
--- /dev/null
+++ b/mindspore/ops/composite/math_ops.py
@@ -0,0 +1,75 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Math operations."""
+from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
+from mindspore.common import dtype as mstype
+from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
+from mindspore.ops import functional as F
+from .. import operations as P
+
+
+@constexpr
+def _check_validate_axis(axis, name):
+    if isinstance(axis, (tuple, list)):
+        for idx, item in enumerate(axis):
+            validator.check_value_type("axis[%d]" % idx, item, [int], name)
+    axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
+    return axis
+
+
+@constexpr
+def _check_validate_keepdims(keep_dims, name):
+    keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
+    return keep_dims
+
+
+def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
+    """
+    Counts the number of nonzero elements across the given axis of the input tensor.
+
+    Args:
+        - **x** (Tensor[Number]) - Input data used to count the non-zero elements.
+        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Only constant value is allowed.
+          Default: (), reduce all dimensions.
+
+        - **keep_dims** (bool) - If true, keep these reduced dimensions and the length is 1.
+          If false, don't keep these dimensions. Default: False.
+        - **dtype** (Union[Number, mstype.bool_]) - The data type of the output tensor. Only constant value is allowed.
+          Default: mstype.int32.
+
+    Returns:
+        Tensor, the number of nonzero elements. The data type is `dtype`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
+        >>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
+        nonzero_num: [[3]]
+    """
+
+    const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
+    axis = _check_validate_axis(axis, "count_nonzero")
+    keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
+    const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
+
+    not_equal = P.NotEqual()
+    cast = P.Cast()
+    reduce_sum = P.ReduceSum(keep_dims)
+    nonzero_bool = not_equal(x, 0)
+    # ReduceSum only supports float16 and float32 tensors.
+    nonzero_val = cast(nonzero_bool, mstype.float16)
+    nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
+
+    return nonzero_num
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index e186ba62f8c..ad1c90d19a1 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -241,9 +241,9 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
         - **max** (Tensor) - Value of the max range of the input data x.
 
     Outputs:
-        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x.
-        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min.
-        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max.
+        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
+        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
+        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.
 
     Examples:
         >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
@@ -356,9 +356,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
         - **max** (Tensor) - Value of the max range of the input data x.
 
     Outputs:
-        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x.
-        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min.
-        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max.
+        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
+        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
+        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.
 
     Examples:
         >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 24c3784afa3..562d77b5f80 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -489,6 +489,7 @@ class DynamicShape(Primitive):
         self.add_prim_attr('is_dynamic_shape', True)
         self.add_prim_attr("dynamic_shape_depends", [0])
 
+
 class Squeeze(PrimitiveWithInfer):
     """
     Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index f0f0cd0bf8c..ad717bc227c 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -26,7 +26,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.operations import _inner_ops as inner
-from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsPerChannel
+from mindspore.ops.operations import _quant_ops as Q
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.mindspore_test import mindspore_test
 from ....mindspore_test_framework.pipeline.forward.compile_forward \
@@ -216,6 +216,31 @@ class HistogramSummaryNet(nn.Cell):
         return out
 
 
+class Moments(nn.Cell):
+    """Moments net definition"""
+
+    def __init__(self, axis=None, keep_dims=None):
+        super(Moments, self).__init__()
+        self.moments = nn.Moments(axis=axis, keep_dims=keep_dims)
+
+    def construct(self, input_x):
+        mean, variance = self.moments(input_x)
+        return mean, variance
+
+
+class CountNonZero(nn.Cell):
+    """CountNonZero net definition"""
+
+    def __init__(self, axis, keep_dims, dtype):
+        super(CountNonZero, self).__init__()
+        self.axis = axis
+        self.keep_dims = keep_dims
+        self.dtype = dtype
+    def construct(self, input_x):
+        nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype)
+        return nonzero_num
+
+
 class ScatterUpdate(nn.Cell):
     """ScatterUpdate net definition"""
 
@@ -1057,14 +1082,22 @@ test_case_math_ops = [
         'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
                         [2, 3], [2, 3]],
         'desc_bprop': [[2, 3]]}),
+    ('Moments', {
+        'block': Moments(axis=(), keep_dims=False),
+        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
+        'skip': ['backward']}),
+    ('CountNonZero', {
+        'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32),
+        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
+        'skip': ['backward']}),
     ('FakeQuantWithMinMaxVars', {
-        'block': FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
+        'block': Q.FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32),
                         Tensor(np.array([-6]), mstype.float32),
                         Tensor(np.array([6]), mstype.float32)],
         'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)]}),
     ('FakeQuantWithMinMaxVarsPerChannel', {
-        'block': FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
+        'block': Q.FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32),
                         Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
                         Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
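
A minimal smoke test for the two new APIs, reproducing the docstring examples above. This is an illustrative sketch, not part of the patch; it assumes the patch is applied to a working MindSpore build, and the `C` and `mstype` aliases follow the conventions already used in tests/ut/python/ops/test_ops.py.

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.ops import composite as C

    # nn.Moments: mean and variance along axis 3, keeping the reduced dimension.
    net = nn.Moments(axis=3, keep_dims=True)
    input_x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
    mean, var = net(input_x)
    # Expected (per the docstring): mean [[[[2.5], [4.5]]]], var [[[[1.25], [1.25]]]]

    # count_nonzero: count the nonzero entries over both axes of a 2-D tensor.
    data = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
    nonzero_num = C.count_nonzero(x=data, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
    # Expected (per the docstring): [[3]]
    print(mean, var, nonzero_num)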