!7390 Add nonzero and moments

Merge pull request !7390 from jiangzhenguang/add_nonzero_and_moments
This commit is contained in:
mindspore-ci-bot 2020-10-22 17:10:41 +08:00 committed by Gitee
commit 22bde4bd9b
6 changed files with 185 additions and 11 deletions

View File

@ -19,12 +19,13 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments']
class ReduceLogSumExp(Cell):
@ -450,3 +451,65 @@ class MatMul(Cell):
matmul_broadcast = self.squeeze_right_op(matmul_broadcast)
return matmul_broadcast
@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
class Moments(Cell):
"""
Calculate the mean and variance of `x`.
Args:
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: ().
keep_dims (bool): If true, The dimension of mean and variance are identical with input's.
If false, don't keep these dimensions. Default: False.
Inputs:
- **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
Outputs:
- **mean** (Tensor) - The mean of input x, with the same date type as input x.
- **variance** (Tensor) - The variance of input x, with the same date type as input x.
Examples:
>>> net = nn.Moments(axis=3, keep_dims=True)
>>> input_x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
>>> mean, var = net(input_x)
mean: [[[[2.5], [4.5]]]]
var: [[[[1.25], [1.25]]]]
"""
def __init__(self, axis=None, keep_dims=None):
super(Moments, self).__init__()
if axis is None:
axis = ()
if isinstance(axis, tuple):
for idx, item in enumerate(axis):
validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
if keep_dims is None:
keep_dims = False
self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.square_diff = P.SquaredDifference()
self.squeeze = P.Squeeze(self.axis)
def construct(self, x):
tensor_dtype = x.dtype
_check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
if tensor_dtype == mstype.float16:
x = self.cast(x, mstype.float32)
mean = self.reduce_mean(x, self.axis)
variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis)
if not self.keep_dims:
mean = self.squeeze(mean)
variance = self.squeeze(variance)
if tensor_dtype == mstype.float16:
mean = self.cast(mean, mstype.float16)
variance = self.cast(variance, mstype.float16)
return mean, variance
return mean, variance

View File

@ -27,6 +27,7 @@ from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from .math_ops import count_nonzero
__all__ = [
@ -47,4 +48,5 @@ __all__ = [
'gamma',
'poisson',
'multinomial',
'clip_by_value',]
'clip_by_value',
'count_nonzero']

View File

@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""math Operations."""
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P
@constexpr
def _check_validate_axis(axis, name):
if isinstance(axis, (tuple, list)):
for idx, item in enumerate(axis):
validator.check_value_type("axis[%d]" % idx, item, [int], name)
axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
return axis
@constexpr
def _check_validate_keepdims(keep_dims, name):
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
return keep_dims
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
"""
Count number of nonzero elements across axis of input tensor
Args:
- **x** (Tensor[Number]) - Input data is used to count non-zero numbers.
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Only constant value is allowed.
Default: (), reduce all dimensions.
- **keep_dims** (bool) - If true, keep these reduced dimensions and the length is 1.
If false, don't keep these dimensions. Default: False.
- **dtype** (Union[Number, mstype.bool_]) - The data type of the output tensor. Only constant value is allowed.
Default: mstype.int32
Returns:
Tensor, number of nonzero element. The data type is dtype.
Examples:
>>> input_tensor = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
>>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
nonzero_num: [[3]]
"""
const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
axis = _check_validate_axis(axis, "count_nonzero")
keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
not_equal = P.NotEqual()
cast = P.Cast()
reduce_sum = P.ReduceSum(keep_dims)
nonzero_bool = not_equal(x, 0)
# ReduceSum only support float16 or float32 tensor.
nonzero_val = cast(nonzero_bool, mstype.float16)
nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
return nonzero_num

View File

@ -241,9 +241,9 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
- **max** (Tensor) - Value of the max range of the input data x.
Outputs:
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x.
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min.
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max.
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x.
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min.
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max.
Examples:
>>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
@ -356,9 +356,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
- **max** (Tensor) - Value of the max range of the input data x.
Outputs:
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x.
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min.
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max.
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x.
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min.
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max.
Examples:
>>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)

View File

@ -489,6 +489,7 @@ class DynamicShape(Primitive):
self.add_prim_attr('is_dynamic_shape', True)
self.add_prim_attr("dynamic_shape_depends", [0])
class Squeeze(PrimitiveWithInfer):
"""
Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.

View File

@ -26,7 +26,7 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsPerChannel
from mindspore.ops.operations import _quant_ops as Q
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -216,6 +216,31 @@ class HistogramSummaryNet(nn.Cell):
return out
class Moments(nn.Cell):
"""Moments net definition"""
def __init__(self, axis=None, keep_dims=None):
super(Moments, self).__init__()
self.moments = nn.Moments(axis=axis, keep_dims=keep_dims)
def construct(self, input_x):
mean, variance = self.moments(input_x)
return mean, variance
class CountNonZero(nn.Cell):
"""CountNonZero net definition"""
def __init__(self, axis, keep_dims, dtype):
super(CountNonZero, self).__init__()
self.axis = axis
self.keep_dims = keep_dims
self.dtype = dtype
def construct(self, input_x):
nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype)
return nonzero_num
class ScatterUpdate(nn.Cell):
"""ScatterUpdate net definition"""
@ -1057,14 +1082,22 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
[2, 3], [2, 3]],
'desc_bprop': [[2, 3]]}),
('Moments', {
'block': Moments(axis=(), keep_dims=False),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
'skip': ['backward']}),
('CountNonZero', {
'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
'skip': ['backward']}),
('FakeQuantWithMinMaxVars', {
'block': FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
'block': Q.FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32),
Tensor(np.array([-6]), mstype.float32),
Tensor(np.array([6]), mstype.float32)],
'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)]}),
('FakeQuantWithMinMaxVarsPerChannel', {
'block': FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
'block': Q.FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32),
Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
Tensor(np.array([6, 1, 2, 3]), mstype.float32)],