!32927 [assistant][ops] Add DivNoNan

Merge pull request !32927 from 张渝/DivNoNan
i-robot 2022-06-18 02:56:32 +00:00 committed by Gitee
commit 2dc90d35f4
7 changed files with 114 additions and 28 deletions

View File

@@ -142,6 +142,7 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
void RealDiv(const T *input1, const T *input2, T *out);
void RealDivComplex(const T *input1, const T *input2, T *out);
void Div(const T *input1, const T *input2, T *out);
+void DivNoNan(const T *input1, const T *input2, T *out);
void FloorDiv(const T *input1, const T *input2, T *out);
void Mod(const T *input1, const T *input2, T *out);
void FloorMod(const T *input1, const T *input2, T *out);
@@ -189,6 +190,7 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::Div},
+{prim::kPrimDivNoNan->name(), &ArithmeticCpuTypeFunc<T>::DivNoNan},
{prim::kPrimMod->name(), &ArithmeticCpuTypeFunc<T>::Mod},
{prim::kPrimFloorMod->name(), &ArithmeticCpuTypeFunc<T>::FloorMod},
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::Pow},
@@ -207,6 +209,7 @@ class ArithmeticCpuTypeFunc : public DeprecatedCpuKernelFunc {
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDivComplex},
{prim::kPrimXdivy->name(), &ArithmeticCpuTypeFunc<T>::XdivyComplex},
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
+{prim::kPrimDivNoNan->name(), &ArithmeticCpuTypeFunc<T>::DivNoNan},
{prim::kPrimAddV2->name(), &ArithmeticCpuTypeFunc<T>::AddV2}};
}
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
@@ -502,6 +505,27 @@ void ArithmeticCpuTypeFunc<T>::DivComplex(const T *input1, const T *input2, T *o
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
}
+template <typename T>
+void ArithmeticCpuTypeFunc<T>::DivNoNan(const T *input1, const T *input2, T *out) {
+BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
+auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
+auto iter = base_iter;
+iter.SetPos(start);
+for (size_t i = start; i < end; i++) {
+auto dividend = input1[iter.GetInputPosA()];
+auto divisor = input2[iter.GetInputPosB()];
+iter.GenNextPos();
+auto zero = (T)0;
+if (divisor == zero) {
+out[i] = zero;
+continue;
+}
+out[i] = dividend / divisor;
+}
+};
+ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
+}
template <typename T>
void ArithmeticCpuTypeFunc<T>::FloorDiv(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
@@ -845,6 +869,23 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFunc
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
SpecializeArithFunc<complex128>}}},
+{prim::kPrimDivNoNan->name(),
+{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+SpecializeArithFunc<float16>},
+{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+SpecializeArithFunc<float>},
+{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+SpecializeArithFunc<double>},
+{KernelAttr()
+.AddInputAttr(kNumberTypeComplex64)
+.AddInputAttr(kNumberTypeComplex64)
+.AddOutputAttr(kNumberTypeComplex64),
+SpecializeArithFunc<complex64>},
+{KernelAttr()
+.AddInputAttr(kNumberTypeComplex128)
+.AddInputAttr(kNumberTypeComplex128)
+.AddOutputAttr(kNumberTypeComplex128),
+SpecializeArithFunc<complex128>}}},
{prim::kPrimPow->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
@@ -1059,6 +1100,9 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mul,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMul); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Div,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDiv->name()); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DivNoNan, []() {
+return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDivNoNan->name());
+});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Pow,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimPow->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RealDiv,
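
For reference, the elementwise behavior the DivNoNan kernel above implements can be sketched in NumPy: broadcast the two inputs, divide elementwise, and write 0 wherever the divisor is 0. This is a minimal sketch; the helper name is illustrative and not part of this change:

import numpy as np

def div_no_nan_ref(x1, x2):
    """Elementwise x1 / x2, returning 0 wherever x2 == 0 (float/complex inputs)."""
    x1, x2 = np.broadcast_arrays(np.asarray(x1), np.asarray(x2))
    out = np.zeros(x1.shape, dtype=np.result_type(x1, x2))
    np.divide(x1, x2, out=out, where=(x2 != 0))  # skipped positions keep the 0
    return out

print(div_no_nan_ref([-1.0, 0.0, 1.0, 5.0, 6.0], [0.0, 0.0, 0.0, 2.0, 3.0]))
# [0.  0.  0.  2.5 2. ]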

View File

@@ -68,6 +68,7 @@ constexpr auto kMulNoNan = "MulNoNan";
constexpr auto kACos = "ACos";
constexpr auto kACosGrad = "ACosGrad";
constexpr auto kRealDiv = "RealDiv";
+constexpr auto kDivNoNan = "DivNoNan";
constexpr auto kReciprocal = "Reciprocal";
constexpr auto kInv = "Inv";
constexpr auto kReduceStd = "ReduceStd";

View File

@@ -16,10 +16,11 @@
#include "ops/div_no_nan.h"
-#include <map>
-#include <string>
+#include <algorithm>
+#include <complex>
+#include <map>
+#include <set>
+#include <string>
#include "abstract/ops/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
@@ -51,26 +52,19 @@ void DivNoNanImpl(void *x1, void *x2, void *result, size_t size) {
}
abstract::ShapePtr DivNoNanInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
-const int64_t valid_size = 2;
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
-(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, valid_size,
-prim_name);
-MS_EXCEPTION_IF_NULL(input_args[0]);
-CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
-MS_EXCEPTION_IF_NULL(input_args[1]);
-CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
return BroadCastInferShape(prim_name, input_args);
}
TypePtr DivNoNanInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
+const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kUInt8, kFloat16,
+kFloat32, kFloat64, kComplex64, kComplex128};
(void)types.emplace("x1", input_args[0]->BuildType());
(void)types.emplace("x2", input_args[1]->BuildType());
-(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
+(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return input_args[0]->BuildType();
}
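
The hunk above swaps the generic common_valid_types for an op-specific set that explicitly includes kComplex64/kComplex128, matching the complex kernels registered for CPU earlier in this PR. A minimal Python sketch of the same-dtype-from-an-allowed-set rule that CheckTensorTypeSame enforces (illustrative names, not MindSpore internals):

def check_tensor_type_same(types, valid_types, prim_name):
    """All named inputs must share one dtype, drawn from valid_types."""
    dtypes = set(types.values())
    if len(dtypes) != 1 or not dtypes.issubset(valid_types):
        raise TypeError(f"For '{prim_name}', input dtypes {types} must be equal "
                        f"and one of {sorted(valid_types)}")

valid = {"int8", "int32", "int64", "uint8", "float16",
         "float32", "float64", "complex64", "complex128"}
check_tensor_type_same({"x1": "float32", "x2": "float32"}, valid, "DivNoNan")  # passes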
@@ -166,8 +160,11 @@ ValuePtr DivNoNanInferValue(const PrimitivePtr &prim, const std::vector<Abstract
MIND_API_OPERATOR_IMPL(DivNoNan, BaseOperator);
AbstractBasePtr DivNoNanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
-auto infer_shape = DivNoNanInferShape(primitive, input_args);
+MS_EXCEPTION_IF_NULL(primitive);
+const int64_t kInputsNum = 2;
+CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = DivNoNanInferType(primitive, input_args);
+auto infer_shape = DivNoNanInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(DivNoNan, prim::kPrimDivNoNan, DivNoNanInfer, DivNoNanInferValue, true);
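
DivNoNanInfer now validates that exactly two inputs were passed, then runs type inference before shape inference. The shape itself comes from BroadCastInferShape, i.e. NumPy-style broadcasting of the two input shapes; a minimal sketch of that rule (illustrative, not the MindSpore implementation):

def broadcast_shape(s1, s2):
    """Right-align the shapes; each dim pair must match or contain a 1."""
    s1, s2 = tuple(s1), tuple(s2)
    s1 = (1,) * (len(s2) - len(s1)) + s1  # pad the shorter shape with 1s
    s2 = (1,) * (len(s1) - len(s2)) + s2
    out = []
    for a, b in zip(s1, s2):
        if a != b and 1 not in (a, b):
            raise ValueError(f"cannot broadcast {s1} with {s2}")
        out.append(max(a, b))
    return tuple(out)

print(broadcast_shape((4, 5), (2, 3, 4, 5)))  # (2, 3, 4, 5)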

View File

@@ -119,6 +119,7 @@ from .uniform_real import _uniform_real_aicpu
from .standard_laplace import _standard_laplace_aicpu
from .strided_slice import _strided_slice_aicpu
from .neg import _neg_aicpu
+from .div_no_nan import _div_no_nan_aicpu
from .strided_slice_grad import _strided_slice_grad_aicpu
from .end_of_sequence import _end_of_sequence_aicpu
from .fused_sparse_adam import _fused_sparse_adam_aicpu

View File

@@ -0,0 +1,35 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""DivNoNan op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+div_no_nan_op_info = AiCPURegOp("DivNoNan") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x1", "required") \
+    .input(1, "x2", "required") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
+    .get_op_info()
+@op_info_register(div_no_nan_op_info)
+def _div_no_nan_aicpu():
+    """DivNoNan AiCPU register"""
+    return
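
With the AiCPU registration above in place, the op can be exercised end to end. A hedged usage sketch, mirroring the updated docstring example in math_ops.py below (assumes a build that includes this PR):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations.math_ops import DivNoNan

x1 = Tensor(np.array([-1.0, 0., 1.0, 5.0, 6.0]), mindspore.float32)
x2 = Tensor(np.array([0., 0., 0., 2.0, 3.0]), mindspore.float32)
print(DivNoNan()(x1, x2))  # expected: [0. 0. 0. 2.5 2. ]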

View File

@@ -2782,11 +2782,11 @@ class Div(_MathBinaryOp):
return None
-class DivNoNan(_MathBinaryOp):
+class DivNoNan(Primitive):
r"""
-Computes a safe divide and returns 0 if the y is zero.
+Computes a safe divide and returns 0 if `x2` is zero.
-Inputs of `x` and `y` comply with the implicit type conversion rules to make the data types consistent.
+Inputs of `x1` and `x2` comply with the implicit type conversion rules to make the data types consistent.
The inputs must be two tensors or one tensor and one scalar.
When the inputs are two tensors,
dtypes of them cannot be bool at the same time, and the shapes of them could be broadcast.
@@ -2795,16 +2795,16 @@ class DivNoNan(_MathBinaryOp):
.. math::
output_{i} = \begin{cases}
-0, & \text{ if } y_{i} = 0\\
-x_{i} / y_{i}, & \text{ if } y_{i} \ne 0
+0, & \text{ if } x2_{i} = 0\\
+x1_{i} / x2_{i}, & \text{ if } x2_{i} \ne 0
\end{cases}
Inputs:
-- **x** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
+- **x1** (Union[Tensor, number.Number, bool]) - The first input is a number.Number or
a bool or a tensor whose data type is
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.
-- **y** (Union[Tensor, number.Number, bool]) - The second input is a number.Number or
+- **x2** (Union[Tensor, number.Number, bool]) - The second input is a number.Number or
a bool when the first input is a tensor or a tensor whose data type is number or bool\_.
When the first input is Scalar, the second input must be a Tensor whose data type is number or bool\_.
@@ -2814,24 +2814,27 @@ class DivNoNan(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.
Raises:
-TypeError: If `x` and `y` is not a number.Number or a bool or a Tensor.
+TypeError: If `x1` or `x2` is not a number.Number, a bool, or a Tensor.
Supported Platforms:
-``Ascend`` ``GPU``
+``Ascend`` ``GPU`` ``CPU``
Examples:
->>> x = Tensor(np.array([-1.0, 0., 1.0, 5.0, 6.0]), mindspore.float32)
->>> y = Tensor(np.array([0., 0., 0., 2.0, 3.0]), mindspore.float32)
->>> div_no_nan = ops.DivNoNan()
->>> output = div_no_nan(x, y)
+>>> from mindspore.ops.operations.math_ops import DivNoNan
+>>> x1 = Tensor(np.array([-1.0, 0., 1.0, 5.0, 6.0]), mindspore.float32)
+>>> x2 = Tensor(np.array([0., 0., 0., 2.0, 3.0]), mindspore.float32)
+>>> div_no_nan = DivNoNan()
+>>> output = div_no_nan(x1, x2)
>>> print(output)
[0. 0. 0. 2.5 2. ]
"""
+__mindspore_signature__ = (sig.sig_dtype.T, sig.sig_dtype.T)
@prim_attr_register
def __init__(self):
-"""Initialize _BinaryOp"""
-self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
+"""Initialize DivNoNan"""
+self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
class MulNoNan(_MathBinaryOp):

View File

@@ -56,6 +56,7 @@ from mindspore.ops.operations.array_ops import MatrixSetDiagV3
from mindspore.ops.operations.array_ops import ScatterNdMax
from mindspore.ops.operations.math_ops import AddV2
from mindspore.ops.operations.math_ops import Lcm
+from mindspore.ops.operations.math_ops import DivNoNan
from mindspore.ops.operations.math_ops import Gcd
from mindspore.ops.operations.math_ops import RaggedRange
from mindspore.ops.operations.array_ops import RangeV2
@@ -1763,6 +1764,10 @@ test_case_math_ops = [
'block': P.Div(),
'desc_inputs': [[4, 5], [2, 3, 4, 5]],
'desc_bprop': [[2, 3, 4, 5]]}),
+('DivNoNan', {
+'block': DivNoNan(),
+'desc_inputs': [[2, 3, 4, 5], [2, 3, 4, 5]],
+'desc_bprop': [[2, 3, 4, 5]]}),
('Equal', {
'block': P.Equal(),
'desc_inputs': [[3, 4, 5], [4, 5]],