diff --git a/mindspore/core/ops/is_nan.cc b/mindspore/core/ops/is_nan.cc
new file mode 100644
index 00000000000..8480b988799
--- /dev/null
+++ b/mindspore/core/ops/is_nan.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "ops/is_nan.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  return std::make_shared<abstract::Shape>(x_shape);
+}
+
+TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  CheckAndConvertUtils::CheckTensorTypeValid(
+    "x", input_args[0]->BuildType(),
+    {kBool, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64, kUInt8, kUInt16, kUInt32, kUInt64},
+    primitive->name());
+  return kBool;
+}
+}  // namespace
+
+AbstractBasePtr IsNanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t input_num = 1;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(IsNan, prim::kPrimIsNan, IsNanInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/is_nan.h b/mindspore/core/ops/is_nan.h
new file mode 100644
index 00000000000..46d90bf2c9e
--- /dev/null
+++ b/mindspore/core/ops/is_nan.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_IS_NAN_H_
+#define MINDSPORE_CORE_OPS_IS_NAN_H_
+#include <memory>
+#include <vector>
+
+#include "ops/primitive_c.h"
+#include "ops/op_utils.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameIsNan = "IsNan";
+class IsNan : public PrimitiveC {
+ public:
+  IsNan() : PrimitiveC(kNameIsNan) { InitIOName({"x"}, {"y"}); }
+  ~IsNan() = default;
+  MS_DECLARE_PARENT(IsNan, PrimitiveC);
+};
+
+AbstractBasePtr IsNanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args);
+using PrimIsNanPtr = std::shared_ptr<IsNan>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_IS_NAN_H_
diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py
index 9ff0a7a52e0..3b88631eacd 100644
--- a/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -35,6 +35,7 @@ from .print_tensor import _print_aicpu
 from .topk import _top_k_aicpu
 from .is_finite import _is_finite_aicpu
 from .is_inf import _is_inf_aicpu
+from .is_nan import _is_nan_aicpu
 from .reshape import _reshape_aicpu
 from .flatten import _flatten_aicpu
 from .squeeze import _squeeze_aicpu
diff --git a/mindspore/ops/_op_impl/aicpu/is_nan.py b/mindspore/ops/_op_impl/aicpu/is_nan.py
new file mode 100644
index 00000000000..95d7a3ce5a0
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/is_nan.py
@@ -0,0 +1,31 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""IsNan op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+is_nan_op_info = AiCPURegOp("IsNan") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+@op_info_register(is_nan_op_info)
+def _is_nan_aicpu():
+    """IsNan AiCPU register"""
+    return
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index e701c1ee561..fc949e0d502 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -3974,7 +3974,7 @@ class LogicalOr(_LogicBinaryOp):
         return None
 
 
-class IsNan(PrimitiveWithInfer):
+class IsNan(Primitive):
     r"""
     Determines which elements are NaN for each position.
 
@@ -4005,7 +4005,7 @@ class IsNan(PrimitiveWithInfer):
         >>> x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
         >>> output = is_nan(x)
         >>> print(output)
-        [True False False]
+        [ True False False]
     """
 
     @prim_attr_register
@@ -4013,12 +4013,6 @@
         """Initialize IsNan"""
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
 
-    def infer_shape(self, x_shape):
-        return x_shape
-
-    def infer_dtype(self, x_dtype):
-        return mstype.tensor_type(mstype.bool_)
-
 
 class IsInf(Primitive):
     r"""
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 779c03eb3cb..ab704338649 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -1183,6 +1183,11 @@ test_case_math_ops = [
         'desc_inputs': [[2, 512, 56, 56]],
         'desc_bprop': [[2, 512, 56, 56]],
         'skip': ['backward']}),
+    ('IsNan', {
+        'block': P.IsNan(),
+        'desc_inputs': [Tensor(np.array([np.log(-1), 1, np.log(0)]).astype(np.float32))],
+        'desc_bprop': [],
+        'skip': ['backward']}),
     ('InplaceAdd', {
         'block': InplaceAddNet(),
         'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)),
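
Reviewer note: a minimal usage sketch (not part of the patch) showing how the op added here is exercised from Python. It simply replays the docstring example from the math_ops.py hunk above, so the only assumption is a backend with the AICPU IsNan kernel registered:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

# np.log(-1) yields NaN and np.log(0) yields -inf, so only the first element is NaN.
is_nan = P.IsNan()
x = Tensor(np.array([np.log(-1), 1, np.log(0)]), mindspore.float32)
output = is_nan(x)
print(output)
# [ True False False]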