!16699 add isnan cpu op & graph mode equal registered types

From: @jachua
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-05-27 10:01:49 +08:00 committed by Gitee
commit 17f666545c
13 changed files with 313 additions and 23 deletions

View File

@ -0,0 +1,93 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/isnan_cpu_kernel.h"
#include <cmath>
#include "abstract/utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void IsNanCPUKernel::InitKernel(const CNodePtr &kernelNode) {
MS_EXCEPTION_IF_NULL(kernelNode);
size_t input_num = AnfAlgo::GetInputTensorNum(kernelNode);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but IsNanCPUKernel needs 1 inputs.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernelNode);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but IsNanCPUKernel needs 1 output.";
}
input_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernelNode, 0);
if (dtype_map_.find(input_dtype_) == dtype_map_.end()) {
MS_LOG(EXCEPTION) << "Unsupported input type found.";
}
}
bool IsNanCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (input_dtype_ == kNumberTypeFloat16) {
LaunchKernelFloat16(inputs, outputs);
} else if (input_dtype_ == kNumberTypeFloat32 || input_dtype_ == kNumberTypeFloat) {
LaunchKernelFloat<float>(inputs, outputs);
} else if (input_dtype_ == kNumberTypeFloat64) {
LaunchKernelFloat<double>(inputs, outputs);
} else if (dtype_map_.find(input_dtype_) != dtype_map_.end()) {
LaunchKernelOther(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Only support bool, int, uint, float, but actual data type is " << TypeIdLabel(input_dtype_);
}
return true;
}
void IsNanCPUKernel::LaunchKernelFloat16(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
float16 *input = reinterpret_cast<float16 *>(inputs[0]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
size_t elem_num = inputs[0]->size / sizeof(float16);
for (size_t i = 0; i < elem_num; i++) {
float temp_num = static_cast<float>(input[i]);
output[i] = std::isnan(temp_num);
}
}
template <typename T>
void IsNanCPUKernel::LaunchKernelFloat(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *input = reinterpret_cast<T *>(inputs[0]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
size_t elem_num = inputs[0]->size / sizeof(T);
for (size_t i = 0; i < elem_num; i++) {
output[i] = std::isnan(input[i]);
}
}
void IsNanCPUKernel::LaunchKernelOther(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
auto type_iter = dtype_map_.find(input_dtype_);
size_t elem_num = inputs[0]->size / (type_iter->second);
for (size_t i = 0; i < elem_num; i++) {
output[i] = false;
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ISNAN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ISNAN_CPU_KERNEL_H_
#include <vector>
#include <map>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class IsNanCPUKernel : public CPUKernel {
public:
IsNanCPUKernel() = default;
~IsNanCPUKernel() override = default;
void InitKernel(const CNodePtr &kernelNode) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
void LaunchKernelFloat(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
void LaunchKernelOther(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
void LaunchKernelFloat16(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
private:
std::map<TypeId, size_t> dtype_map_ = {{kNumberTypeBool, sizeof(bool)}, {kNumberTypeInt8, sizeof(int8_t)},
{kNumberTypeInt16, sizeof(int16_t)}, {kNumberTypeInt32, sizeof(int32_t)},
{kNumberTypeInt64, sizeof(int64_t)}, {kNumberTypeFloat16, sizeof(float16)},
{kNumberTypeFloat32, sizeof(float)}, {kNumberTypeFloat64, sizeof(double)},
{kNumberTypeUInt8, sizeof(uint8_t)}, {kNumberTypeUInt16, sizeof(uint16_t)},
{kNumberTypeUInt32, sizeof(uint32_t)}, {kNumberTypeUInt64, sizeof(uint64_t)}};
TypeId input_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
MS_REG_CPU_KERNEL(IsNan, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), IsNanCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ISNAN_CPU_KERNEL_H_

View File

@ -34,6 +34,8 @@ using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
@ -64,6 +66,13 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
elems.push_back(
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToLong(i))}));
}
} else if (args_spec_list[index]->isa<AbstractList>()) {
auto arg_list = args_spec_list[index]->cast<AbstractListPtr>();
AnfNodePtr para_list = ret_graph->add_parameter();
for (size_t i = 0; i < arg_list->size(); ++i) {
elems.push_back(
ret_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), para_list, NewValueNode(SizeToLong(i))}));
}
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
AnfNodePtr para_dict = ret_graph->add_parameter();

View File

@ -2060,6 +2060,10 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable
not provided, `range` is simply ``(a.min(), a.max())``. Values outside
the range are ignored. The first element of the range must be less than
or equal to the second.
weights(Union[int, float, bool, list, tuple, Tensor], optional): An array of weights,
of the same shape as `a`. Each value in `a` only contributes its associated weight
towards the bin count (instead of 1). This is currently not used by any of the bin
estimators, but may be in the future.
Returns:
Tensor, the edges to pass into `histogram`.
@ -2076,6 +2080,11 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable
>>> print(np.histogram_bin_edges(arr, bins=2))
[0. 2.5 5. ]
"""
a = _to_tensor(a)
if weights is not None:
weights = _to_tensor(weights)
if F.shape(a) != F.shape(weights):
_raise_value_error('weights should have the same shape as a')
if isinstance(bins, (tuple, list, Tensor)):
bins = _to_tensor(bins)
if F.rank(bins) != 1:
@ -2084,12 +2093,15 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None): # pylint: disable
if isinstance(bins, str):
# linspace does not support Tensor for num
_raise_unimplemented_error('string value for `bins` not implemented')
a = _to_tensor(a).ravel().astype(mstype.float32)
a = a.ravel().astype(mstype.float32)
if range is None:
start = F.reduce_min(a)
end = F.reduce_max(a)
else:
start, end = _to_tensor(*range)
start, end = range
if start > end:
_raise_value_error('max must be larger than min in range parameter')
start, end = _to_tensor(start, end)
no_range = (end - start) == 0
start = where(no_range, start - 0.5, start)
end = where(no_range, end + 0.5, end)

View File

@ -30,7 +30,7 @@ from .utils_const import _check_axes_range, _check_start_normalize, \
_check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \
_list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \
_tuple_slice, _expanded_shape, _seq_prod, _tuple_setitem, _iota, \
_raise_unimplemented_error, _cumprod, _get_device
_raise_unimplemented_error, _cumprod, _get_device, _check_is_int
# According to official numpy reference, the dimension of a numpy array must be less
# than 32
@ -2164,6 +2164,8 @@ def choose(a, choices, mode='clip'):
[ 10 -10 10]]
"""
a = _to_tensor(a)
if not _check_is_int(F.dtype(a)):
_raise_value_error('`a` should be an int array')
if isinstance(choices, (tuple, list)):
# broadcasts choices to the same shape if choices is a sequence
choices = _to_tensor(*choices)
@ -2183,14 +2185,10 @@ def choose(a, choices, mode='clip'):
if F.rank(a) == 0 or F.rank(choices) == 0:
_raise_value_error('input cannot be scalars')
a = broadcast_to(a, shape_choice)
dtype = F.dtype(choices)
# adjusts dtype for F.tensor_mul and F.gather_nd
a = a.astype(mstype.int32)
choices = choices.astype(mstype.int32)
a = _check_indices(F.shape(choices)[0], a, mode, allow_negative_index=False)
grid = _get_grid(F.shape(a))
indices = concatenate((a.reshape(F.shape(a) + (1,)), grid), -1)
return F.gather_nd(choices, indices).astype(dtype)
return F.gather_nd(choices, indices)
def size(a, axis=None):

View File

@ -32,7 +32,8 @@ from .dtypes import nan, pi, dtype_map, inf
from .array_creations import asarray_const, ones, zeros, empty, full, full_like, diag, \
arange, histogram_bin_edges, eye
from .array_ops import where as where_
from .array_ops import ravel, expand_dims, moveaxis, concatenate, flip, stack, atleast_1d
from .array_ops import ravel, expand_dims, moveaxis, concatenate, flip, stack, atleast_1d, \
split
from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
_check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
@ -40,9 +41,9 @@ from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
_is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \
_check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \
_make_tensor, _promote_for_trigonometric, _raise_runtime_error, _max, _type_convert, \
_raise_unimplemented_error, _abs, _in
_raise_unimplemented_error, _abs, _in, _tuple_slice
from .utils import _expand, _broadcast_to, _broadcast_to_shape, _check_input_tensor, \
_to_tensor, _isnan, _to_tensor_origin_dtype
_to_tensor, _to_tensor_origin_dtype, _isnan
ZERO_TENSOR = asarray_const(0)
@ -1244,6 +1245,9 @@ def log(x, dtype=None):
def _prop_nan(fn, x1, x2):
"""Selects NaN if either element is NaN"""
if _get_device() == 'Ascend':
# F.isnan is not supported on Ascend
return fn(x1, x2)
has_nan = F.logical_or(_isnan(x1), _isnan(x2))
nan_tensor = F.fill(_promote(F.dtype(x1), F.dtype(x2)), F.shape(has_nan), nan)
res = fn(x1, x2)
@ -4132,9 +4136,15 @@ def multi_dot(arrays):
[500000. 500000. 500000. ... 500000. 500000. 500000.]
[500000. 500000. 500000. ... 500000. 500000. 500000.]]
"""
arrays = _to_tensor(*arrays)
if len(arrays) < 2:
_raise_value_error('Expecting at least 2 arrays')
if isinstance(arrays, (tuple, list)):
arrays = _to_tensor(*arrays)
else:
arrays = _to_tensor(arrays)
num = len(arrays)
arrays = F.reshape(arrays, (-1,) + _tuple_slice(F.shape(arrays), 2, None))
arrays = split(arrays, num)
if len(arrays) == 2:
return dot(*arrays)
@ -5678,7 +5688,7 @@ def invert(x, dtype=None):
output Tensor.
Returns:
Tensor or scalar, this is a scalar if both x1 and x2 are scalars.
Tensor or scalar.
Supported Platforms:
``Ascend``

View File

@ -20,7 +20,7 @@ from ..ops import functional as F
from ..common import dtype as mstype
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
_tuple_setitem, _callable_const
_tuple_setitem, _callable_const, _check_is_float
def _deep_list(array_like):
@ -154,11 +154,6 @@ def _get_dtype_from_scalar(*input_numbers):
return mstype.float32
def _isnan(x):
"""Computes isnan."""
return F.not_equal(x, x)
def _convert_bool_to_int(tensor):
"""Convert tensor with bool type to int32."""
if tensor.dtype == mstype.bool_:
@ -206,3 +201,9 @@ def _callable(tensor, obj):
if F.isconstant(tensor):
return isinstance(obj, types.FunctionType)
return _callable_const(F.typeof(obj))
def _isnan(x):
if _check_is_float(F.dtype(x)):
return F.isnan(x)
return F.fill(mstype.bool_, F.shape(x), False)

View File

@ -296,3 +296,33 @@ def _none_equal_list(x, y):
bool, return false.
"""
return False
@equal.register("Number", "String")
def _number_equal_string(x, y):
"""
Determine if number equal string.
Args:
x (Number): The first input which is a number.
y (String): The second input which is a string.
Returns:
bool, return false.
"""
return False
@equal.register("String", "Number")
def _string_equal_number(x, y):
"""
Determine if number equal string.
Args:
x (String): The first input which is a string.
y (Number): The second input which is a number.
Returns:
bool, return false.
"""
return False

View File

@ -79,6 +79,7 @@ check_bprop = P.CheckBprop()
equal = P.Equal()
not_equal = P.NotEqual()
isfinite = P.IsFinite()
isnan = P.IsNan()
assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign()

View File

@ -3293,7 +3293,7 @@ class IsNan(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``GPU``
``GPU`` ``CPU``
Examples:
>>> is_nan = ops.IsNan()

View File

@ -1256,8 +1256,8 @@ def test_select():
def test_choose():
x = rand_int(2, 1, 4).astype(onp.int32)
y = rand_int(3, 2, 5, 4).astype(onp.int32)
match_res(mnp.choose, onp.choose, x, y, mode='wrap')
match_res(mnp.choose, onp.choose, x, y, mode='clip')
match_res(mnp.choose, onp.choose, x, y, mode='wrap', dtype=mnp.int32)
match_res(mnp.choose, onp.choose, x, y, mode='clip', dtype=mnp.int32)
x = rand_int(5, 3, 1, 7).astype(onp.int32)
y1 = rand_int(7).astype(onp.int32)

View File

@ -920,7 +920,7 @@ def onp_maximum(x1, x2):
return onp.maximum(x1, x2)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard

View File

@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Netnan(nn.Cell):
def __init__(self):
super(Netnan, self).__init__()
self.isnan = P.IsNan()
def construct(self, x):
return self.isnan(x)
x1 = np.array([[1.2, 2, np.nan, 88]]).astype(np.float32)
x2 = np.array([[np.inf, 1, 88.0, 0]]).astype(np.float32)
x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_nan():
ms_isnan = Netnan()
output1 = ms_isnan(Tensor(x1))
expect1 = [[False, False, True, False]]
assert (output1.asnumpy() == expect1).all()
output2 = ms_isnan(Tensor(x2))
expect2 = [[False, False, False, False]]
assert (output2.asnumpy() == expect2).all()
output3 = ms_isnan(Tensor(x3))
expect3 = [[False, False], [False, False], [False, False]]
assert (output3.asnumpy() == expect3).all()