add moment and nonzero.
This commit is contained in:
parent
9593e82e56
commit
374e9e199d
|
@ -19,12 +19,13 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import functional as F
|
||||
from ..cell import Cell
|
||||
from ...common import dtype as mstype
|
||||
from ..._checkparam import Validator as validator
|
||||
|
||||
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
|
||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul', 'Moments']
|
||||
|
||||
|
||||
class ReduceLogSumExp(Cell):
|
||||
|
@ -451,3 +452,65 @@ class MatMul(Cell):
|
|||
matmul_broadcast = self.squeeze_right_op(matmul_broadcast)
|
||||
|
||||
return matmul_broadcast
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
class Moments(Cell):
|
||||
"""
|
||||
Calculate the mean and variance of `x`.
|
||||
|
||||
Args:
|
||||
axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: ().
|
||||
keep_dims (bool): If true, The dimension of mean and variance are identical with input's.
|
||||
If false, don't keep these dimensions. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
|
||||
|
||||
Outputs:
|
||||
- **mean** (Tensor) - The mean of input x, with the same date type as input x.
|
||||
- **variance** (Tensor) - The variance of input x, with the same date type as input x.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Moments(axis=3, keep_dims=True)
|
||||
>>> input_x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
|
||||
>>> mean, var = net(input_x)
|
||||
mean: [[[[2.5], [4.5]]]]
|
||||
var: [[[[1.25], [1.25]]]]
|
||||
"""
|
||||
|
||||
def __init__(self, axis=None, keep_dims=None):
|
||||
super(Moments, self).__init__()
|
||||
if axis is None:
|
||||
axis = ()
|
||||
if isinstance(axis, tuple):
|
||||
for idx, item in enumerate(axis):
|
||||
validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
|
||||
self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
|
||||
if keep_dims is None:
|
||||
keep_dims = False
|
||||
self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
|
||||
self.cast = P.Cast()
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
||||
self.square_diff = P.SquaredDifference()
|
||||
self.squeeze = P.Squeeze(self.axis)
|
||||
|
||||
def construct(self, x):
|
||||
tensor_dtype = x.dtype
|
||||
_check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
if tensor_dtype == mstype.float16:
|
||||
x = self.cast(x, mstype.float32)
|
||||
mean = self.reduce_mean(x, self.axis)
|
||||
variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis)
|
||||
if not self.keep_dims:
|
||||
mean = self.squeeze(mean)
|
||||
variance = self.squeeze(variance)
|
||||
if tensor_dtype == mstype.float16:
|
||||
mean = self.cast(mean, mstype.float16)
|
||||
variance = self.cast(variance, mstype.float16)
|
||||
return mean, variance
|
||||
return mean, variance
|
||||
|
|
|
@ -27,6 +27,7 @@ from .multitype_ops.add_impl import hyper_add
|
|||
from .multitype_ops.ones_like_impl import ones_like
|
||||
from .multitype_ops.zeros_like_impl import zeros_like
|
||||
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
|
||||
from .math_ops import count_nonzero
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -47,4 +48,5 @@ __all__ = [
|
|||
'gamma',
|
||||
'poisson',
|
||||
'multinomial',
|
||||
'clip_by_value',]
|
||||
'clip_by_value',
|
||||
'count_nonzero']
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""math Operations."""
|
||||
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import functional as F
|
||||
from .. import operations as P
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_validate_axis(axis, name):
|
||||
if isinstance(axis, (tuple, list)):
|
||||
for idx, item in enumerate(axis):
|
||||
validator.check_value_type("axis[%d]" % idx, item, [int], name)
|
||||
axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
|
||||
return axis
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_validate_keepdims(keep_dims, name):
|
||||
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
|
||||
return keep_dims
|
||||
|
||||
|
||||
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
||||
"""
|
||||
Count number of nonzero elements across axis of input tensor
|
||||
|
||||
Args:
|
||||
- **x** (Tensor[Number]) - Input data is used to count non-zero numbers.
|
||||
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Only constant value is allowed.
|
||||
Default: (), reduce all dimensions.
|
||||
|
||||
- **keep_dims** (bool) - If true, keep these reduced dimensions and the length is 1.
|
||||
If false, don't keep these dimensions. Default: False.
|
||||
- **dtype** (Union[Number, mstype.bool_]) - The data type of the output tensor. Only constant value is allowed.
|
||||
Default: mstype.int32
|
||||
|
||||
Returns:
|
||||
Tensor, number of nonzero element. The data type is dtype.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
|
||||
>>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
|
||||
nonzero_num: [[3]]
|
||||
"""
|
||||
|
||||
const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
|
||||
axis = _check_validate_axis(axis, "count_nonzero")
|
||||
keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
|
||||
const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
|
||||
|
||||
not_equal = P.NotEqual()
|
||||
cast = P.Cast()
|
||||
reduce_sum = P.ReduceSum(keep_dims)
|
||||
nonzero_bool = not_equal(x, 0)
|
||||
# ReduceSum only support float16 or float32 tensor.
|
||||
nonzero_val = cast(nonzero_bool, mstype.float16)
|
||||
nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
|
||||
|
||||
return nonzero_num
|
|
@ -241,9 +241,9 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
|
|||
- **max** (Tensor) - Value of the max range of the input data x.
|
||||
|
||||
Outputs:
|
||||
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x.
|
||||
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min.
|
||||
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max.
|
||||
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x.
|
||||
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min.
|
||||
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max.
|
||||
|
||||
Examples:
|
||||
>>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||
|
@ -356,9 +356,9 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
|
|||
- **max** (Tensor) - Value of the max range of the input data x.
|
||||
|
||||
Outputs:
|
||||
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape date type as input x.
|
||||
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape date type as input min.
|
||||
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape date type as input max.
|
||||
- **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x.
|
||||
- **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min.
|
||||
- **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max.
|
||||
|
||||
Examples:
|
||||
>>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
|
||||
|
|
|
@ -489,6 +489,7 @@ class DynamicShape(Primitive):
|
|||
self.add_prim_attr('is_dynamic_shape', True)
|
||||
self.add_prim_attr("dynamic_shape_depends", [0])
|
||||
|
||||
|
||||
class Squeeze(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsPerChannel
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
|
@ -216,6 +216,31 @@ class HistogramSummaryNet(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class Moments(nn.Cell):
|
||||
"""Moments net definition"""
|
||||
|
||||
def __init__(self, axis=None, keep_dims=None):
|
||||
super(Moments, self).__init__()
|
||||
self.moments = nn.Moments(axis=axis, keep_dims=keep_dims)
|
||||
|
||||
def construct(self, input_x):
|
||||
mean, variance = self.moments(input_x)
|
||||
return mean, variance
|
||||
|
||||
|
||||
class CountNonZero(nn.Cell):
|
||||
"""CountNonZero net definition"""
|
||||
|
||||
def __init__(self, axis, keep_dims, dtype):
|
||||
super(CountNonZero, self).__init__()
|
||||
self.axis = axis
|
||||
self.keep_dims = keep_dims
|
||||
self.dtype = dtype
|
||||
def construct(self, input_x):
|
||||
nonzero_num = C.count_nonzero(input_x, self.axis, self.keep_dims, self.dtype)
|
||||
return nonzero_num
|
||||
|
||||
|
||||
class ScatterUpdate(nn.Cell):
|
||||
"""ScatterUpdate net definition"""
|
||||
|
||||
|
@ -1057,14 +1082,22 @@ test_case_math_ops = [
|
|||
'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
|
||||
[2, 3], [2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
('Moments', {
|
||||
'block': Moments(axis=(), keep_dims=False),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('CountNonZero', {
|
||||
'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('FakeQuantWithMinMaxVars', {
|
||||
'block': FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
|
||||
'block': Q.FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32),
|
||||
Tensor(np.array([-6]), mstype.float32),
|
||||
Tensor(np.array([6]), mstype.float32)],
|
||||
'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)]}),
|
||||
('FakeQuantWithMinMaxVarsPerChannel', {
|
||||
'block': FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
|
||||
'block': Q.FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32),
|
||||
Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
|
||||
Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
|
||||
|
|
Loading…
Reference in New Issue