fix_nansum_master

This commit is contained in:
yide12 2023-01-31 16:05:22 +08:00
parent ea725736a6
commit 716369ee59
16 changed files with 202 additions and 118 deletions

View File

@ -505,6 +505,7 @@ Array操作
mindspore.ops.movedim
mindspore.ops.narrow
mindspore.ops.nan_to_num
mindspore.ops.nansum
mindspore.ops.normal
mindspore.ops.nonzero
mindspore.ops.numel

View File

@ -3,11 +3,4 @@ mindspore.Tensor.all
.. py:method:: mindspore.Tensor.all(axis=(), keep_dims=False)
检查在指定轴上所有元素是否均为True。
参数:
- **axis** (Union[None, int, tuple(int)]) - 计算all的维度。当 `axis` 为None或者空元组的时候计算所有维度。当 `axis` 为int或tuple(int)时记Tensor的维度为dim则其取值范围为[-dim, dim)。默认值:()。
- **keep_dims** (bool) - 计算结果是否保留维度。默认值False。
返回:
Tensor。如果在指定轴方向上所有数组元素都为True则其值为True否则其值为False。如果轴为None或空元组则默认降维。
详情请参考 :func:`mindspore.ops.all`

View File

@ -3,11 +3,4 @@ mindspore.Tensor.any
.. py:method:: mindspore.Tensor.any(axis=(), keep_dims=False)
检查在指定轴方向上是否存在任意为True的Tensor元素。
参数:
- **axis** (Union[None, int, tuple(int)]) - 计算any的维度。当 `axis` 为None或空元组时计算所有维度。当 `axis` 为int或tuple(int)时记Tensor的维度为dim则其取值范围为[-dim, dim)。默认值:()。
- **keep_dims** (bool) - 计算结果是否保留维度。默认值False。
返回:
Tensor。如果在指定轴方向上存在任意Tensor元素为True则其值为True否则其值为False。如果轴为None或空元组则默认降维。
详情请参考 :func:`mindspore.ops.any`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.nansum
=======================
.. py:method:: mindspore.Tensor.nansum(axis=None, keepdims=False, dtype=None)
详情请参考 :func:`mindspore.ops.nansum`

View File

@ -200,6 +200,7 @@ mindspore.Tensor
mindspore.Tensor.mT
mindspore.Tensor.multiply
mindspore.Tensor.nan_to_num
mindspore.Tensor.nansum
mindspore.Tensor.narrow
mindspore.Tensor.nbytes
mindspore.Tensor.ndim

View File

@ -0,0 +1,30 @@
mindspore.ops.nansum
====================
.. py:function:: mindspore.ops.nansum(x, axis=None, keepdims=False, *, dtype=None)
计算 `x` 指定维度元素的和,将非数字(NaNs)处理为零。
参数:
- **x** (Tensor) - 输入Tensor。
- **axis** (Union[int, tuple(int)], 可选) - 求和的维度。假设 `x` 的秩为r取值范围[-r,r)。默认值None对Tensor中的所有元素求和。
- **keepdims** (bool, 可选) - 输出Tensor是否保持维度。默认值False不保留维度。
关键字参数:
- **dtype** (:class:`mindspore.dtype`, 可选) - 输出Tensor的类型。默认值None。
返回:
Tensor输入 `x` 指定维度的元素和,将非数字(NaNs)处理为零。
- 如果 `axis` 为None`keep_dims` 为False
则输出一个零维Tensor表示输入Tensor中所有元素的和。
- 如果 `axis` 为int值为2并且 `keep_dims` 为False
则输出的shape为 :math:`(x_1, x_3, ..., x_R)`
- 如果 `axis` 为tuple(int)或list(int),值为(2, 3),并且 `keep_dims` 为False
则输出的shape为 :math:`(x_1, x_4, ..., x_R)`
异常:
- **TypeError** - `x` 不是一个Tensor。
- **TypeError** - `keepdims` 不是bool类型。
- **TypeError** - `x` 的数据类型或 `dtype` 是complex类型。
- **ValueError** - `axis` 的范围不在[-r, r)r表示 `x` 的秩。

View File

@ -206,6 +206,7 @@
mindspore.Tensor.mT
mindspore.Tensor.multiply
mindspore.Tensor.nan_to_num
mindspore.Tensor.nansum
mindspore.Tensor.narrow
mindspore.Tensor.nbytes
mindspore.Tensor.ndim

View File

@ -505,6 +505,7 @@ Array Operation
mindspore.ops.moveaxis
mindspore.ops.movedim
mindspore.ops.nan_to_num
mindspore.ops.nansum
mindspore.ops.normal
mindspore.ops.nonzero
mindspore.ops.numel

View File

@ -446,6 +446,7 @@ BuiltInTypeMap &GetMethodMap() {
{"mul", std::string("mul")}, // mul()
{"multiply", std::string("multiply")}, // multiply()
{"nan_to_num", std::string("nan_to_num")}, // nan_to_num()
{"nansum", std::string("nansum")}, // nansum()
{"neg", std::string("neg")}, // neg()
{"ne", std::string("ne")}, // ne()
{"not_equal", std::string("not_equal")}, // not_equal()

View File

@ -2024,6 +2024,13 @@ def sum_to_size(x, *size):
return x
def nansum(x, axis=None, keepdims=False, *, dtype=None):
"""
Computes sum of all elements, treating NaNs as zero.
"""
return F.nansum(x, axis=axis, keepdims=keepdims, dtype=dtype)
def repeat(x, repeats, axis=None):
"""
Repeat elements of an array.

View File

@ -861,33 +861,9 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('adjoint')(self)
def all(self, axis=(), keep_dims=False):
r"""
For details, please refer to :func:`mindspore.ops.all`.
"""
Check all tensor elements along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)]): Dimensions of reduction.
When the axis is None or empty tuple, reduce all dimensions. When the axis is int or
tuple(int), if the dimension of Tensor is dim, the value range is [-dim, dim). Default: ().
keep_dims (bool): Whether to keep the reduced dimensions. Default: False.
Returns:
Tensor, if all tensor elements along the given axis evaluate to True, its value is True,
otherwise its value is False. If the axis is None or empty tuple, reduce all dimensions.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
See also:
:func:`mindspore.Tensor.any`: Check any tensor element along a given axis evaluate to True.
Examples:
>>> from mindspore import Tensor
>>> a = Tensor([True, True, False])
>>> output = a.all()
>>> print(output)
False
"""
self._init_check()
if axis is None:
axis = ()
@ -901,33 +877,9 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('angle')(self)
def any(self, axis=(), keep_dims=False):
r"""
For details, please refer to :func:`mindspore.ops.any`.
"""
Check any tensor element along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)]): Dimensions of reduction.
When the axis is None or empty tuple, reduce all dimensions. When the axis is int or
tuple(int), if the dimension of Tensor is dim, the value range is [-dim, dim). Default: ().
keep_dims (bool): Whether to keep the reduced dimensions. Default: False.
Returns:
Tensor, if any tensor element along the given axis evaluates to True, its value is True,
otherwise its value is False. If the axis is None or empty tuple, reduce all dimensions.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
See also:
:func:`mindspore.Tensor.all`: Check all tensor elements along a given axis evaluate to True.
Examples:
>>> from mindspore import Tensor
>>> a = Tensor([True, True, False])
>>> output = a.any()
>>> print(output)
True
"""
self._init_check()
if axis is None:
axis = ()
@ -1347,18 +1299,21 @@ class Tensor(Tensor_):
r"""
For details, please refer to :func:`mindspore.ops.logaddexp`.
"""
self._init_check()
return tensor_operator_registry.get('logaddexp')(self, other)
def logaddexp2(self, other):
r"""
For details, please refer to :func:`mindspore.ops.logaddexp2`.
"""
self._init_check()
return tensor_operator_registry.get('logaddexp2')(self, other)
def logsumexp(self, dim, keepdim=False):
r"""
For details, please refer to :func:`mindspore.ops.logsumexp`.
"""
self._init_check()
return tensor_operator_registry.get('logsumexp')(self, dim, keepdim)
def logdet(self):
@ -1393,18 +1348,21 @@ class Tensor(Tensor_):
r"""
For details, please refer to :func:`mindspore.ops.isneginf`.
"""
self._init_check()
return tensor_operator_registry.get('isneginf')(self)
def isposinf(self):
r"""
For details, please refer to :func:`mindspore.ops.isposinf`.
"""
self._init_check()
return tensor_operator_registry.get('isposinf')(self)
def isreal(self):
r"""
For details, please refer to :func:`mindspore.ops.isreal`.
"""
self._init_check()
return tensor_operator_registry.get('isreal')(self)
def isfinite(self):
@ -2994,6 +2952,13 @@ class Tensor(Tensor_):
return x.sum(tuple(axes), keepdims=True)
return x
def nansum(self, axis=None, keepdims=False, dtype=None):
"""
For details, please refer to :func:`mindspore.ops.nansum`.
"""
self._init_check()
return tensor_operator_registry.get('nansum')(self, axis=axis, keepdims=keepdims, dtype=dtype)
def repeat(self, repeats, axis=None):
"""
Repeat elements of a tensor.

View File

@ -245,6 +245,7 @@ from .math_func import (
maximum,
median,
nan_to_num,
nansum,
logaddexp,
logaddexp2,
logit,

View File

@ -9739,76 +9739,63 @@ def imag(input):
return _get_cache_prim(P.Imag)()(input)
def nansum(x, axis, keepdims=False, *, dtype=None):
def nansum(x, axis=None, keepdims=False, *, dtype=None):
"""
Computes sum of all elements, treating Not a Numbers (NaNs) as zero.
Computes sum of `x` over a given dimension, treating NaNs as zero.
Args:
x (Tensor) - The input tensor.
axis (Union[int, tuple(int)]) - The dimensions to reduce. Must be in the range [-rank(`x`), rank(`x`)).
keepdims (bool, optional) - Whether the output tensor has dim retained or not. Default: False.
dtype (mindspore type, optional) - The desired data type of returned tensor. Default: None.
x (Tensor) - The input Tensor.
axis (Union[int, tuple(int)], optional) - The dimensions to reduce. Supposed the rank of `x` is r,
axis must be in the range [-rank(x), rank(x)). Default: None, all dimensions are reduced.
keepdims (bool, optional) - Whether the output Tensor keeps dimensions or not. Default: False.
Keyword Args:
dtype (:class:`mindspore.dtype`, optional): The dtype of output Tensor. Default: None.
Returns:
Tensor, the sum of each row of the input tensor in the given dimension dim,
treating Not a Numbers (NaNs) as zero.
Tensor, the sum of input `x` in the given dimension dim, treating NaNs as zero.
- If axis is (), keepdims is False,
the output is a 0-D tensor representing the sum of all elements in the input tensor.
- If axis is None, keepdims is False,
the output is a 0-D Tensor representing the sum of all elements in the input Tensor.
- If axis is int, set as 2, and keepdims is False,
the shape of output is :math:`(x_1, x_3, ..., x_R)`.
- If axis is tuple(int) or list(int), set as (2, 3), and keepdims is False,
the shape of output is :math:`(x_1, x_4, ..., x_R)`.
- If x_dtype or dtype is complex type, nansum does not supported.
- If x_dtype is floating-point type, and dtype is integer type, nansum does not supported.
Raises:
TypeError: If `x` is not tensor.
TypeError: If `x` is not Tensor.
TypeError: If `keepdims` is not a bool.
TypeError: If x_dtype or dtype is complex type.
TypeError: If x_dtype is floating-point type, and dtype is integer type.
valueError: If 'axis' not in [-rank(`x`), rank(`x`)).
TypeError: If the dtype of `x` or `dtype` is complex type.
ValueError: If 'axis' not in [-rank(`x`), rank(`x`)).
Supported Platforms:
``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
>>> axis = [0]
>>> output = ops.nansum(x, axis, dtype=mindspore.float32)
>>> print(output)
[2. 4. 6.]
>>> x = Tensor(np.array([[float("nan"), 2, 3], [1, 2, float("nan")]]), mindspore.float32)
>>> output1 = ops.nansum(x, axis=0, keepdims=False, dtype=mindspore.float32)
>>> output2 = ops.nansum(x, axis=0, keepdims=True, dtype=mindspore.float32)
>>> print(output1)
[1. 4. 3.]
>>> print(output2)
[[1. 4. 3.]]
"""
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError("For nansum, input must be Tensor.")
res_dtype = dtype
dtype_op = P.DType()
x_dtype = dtype_op(x)
if (x_dtype is not None and x_dtype in (mstype.complex64, mstype.complex128)) or \
(dtype is not None and dtype in (mstype.complex64, mstype.complex128)):
raise TypeError('nansum not supported complex type.')
if x_dtype == mstype.bool_:
if not isinstance(x, Tensor):
raise TypeError(f"For nansum, input must be Tensor, but got {type(x)}.")
if x.is_complex():
raise TypeError(f'For nansum, input are not supported complex type, but got {type(x)}.')
if dtype is not None and dtype in mstype.complex_type:
raise TypeError(f'For nansum, dtype not supported complex type, but got {dtype}.')
if axis is None:
axis = ()
if x.dtype == mstype.bool_:
x = x.astype(mstype.int64)
if dtype is None:
if x_dtype not in (mstype.float32, mstype.float16, mstype.float64):
dtype = mstype.int64
else:
dtype = x_dtype
if x_dtype in (mstype.float32, mstype.float16, mstype.float64):
if dtype not in (mstype.float32, mstype.float16, mstype.float64):
raise TypeError(f'nansum not supported for this dtype {dtype} when x_dtype is floa16, float32 or float64')
get_nan = P.IsNan()(x)
x = P.MaskedFill()(x, get_nan, Tensor(0.0, dtype=x_dtype))
if x_dtype != dtype:
is_nan = _get_cache_prim(P.IsNan)()(x)
x = ops.masked_fill(x, is_nan, 0)
x = _get_cache_prim(P.ReduceSum)(keepdims)(x, axis)
if dtype is not None and x.dtype != dtype:
x = x.astype(dtype)
res = P.ReduceSum(keepdims)(x, axis)
if (res_dtype is not None) and (res_dtype == mstype.bool_):
res = res.astype(res_dtype)
return res
return x
def diag_embed(x, offset=0, dim1=-2, dim2=-1):
@ -10048,6 +10035,7 @@ __all__ = [
'mul',
'multiply',
'nan_to_num',
'nansum',
'digamma',
'lgamma',
'tensor_div',

View File

@ -328,6 +328,7 @@ tensor_operator_registry.register('argsort', argsort)
tensor_operator_registry.register('msort', msort)
tensor_operator_registry.register('mm', mm)
tensor_operator_registry.register('nan_to_num', nan_to_num)
tensor_operator_registry.register('nansum', nansum)
tensor_operator_registry.register('csr_to_coo', csr_to_coo)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('ones', ones)

View File

@ -0,0 +1,48 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
return ops.nansum(x, axis=0, keepdims=True, dtype=ms.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_nansum(mode):
"""
Feature: ops.nansum
Description: Verify the result of nansum
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor([[float("nan"), 128.1, -256.9], [float("nan"), float("nan"), 128]], ms.float32)
net = Net()
output = net(x)
expect_output = [[0, 128, -128]]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return x.nansum(axis=0, keepdims=True, dtype=ms.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_nansum(mode):
"""
Feature: tensor.nansum
Description: Verify the result of nansum
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor([[float("nan"), 128.1, -256.9], [float("nan"), float("nan"), 128]], ms.float32)
net = Net()
output = net(x)
expect_output = [[0, 128, -128]]
assert np.allclose(output.asnumpy(), expect_output)