diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/check_numerics_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/check_numerics_cpu_kernel.cc
new file mode 100644
index 00000000000..477a61ce9aa
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/check_numerics_cpu_kernel.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/cpu/kernel/check_numerics_cpu_kernel.h"
+#include <cmath>
+#include "abstract/utils.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kCheckNumericsInputsNum = 1;
+constexpr size_t kCheckNumericsOutputsNum = 1;
+}  // namespace
+
+void CheckNumericsCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+  if (dtype_map_.find(input_dtype_) == dtype_map_.end()) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_
+                      << "', the dtype of 'x' should be float16, float32 or float64, but got: " << input_dtype_;
+  }
+}
+
+bool CheckNumericsCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
+                                       const std::vector<AddressPtr> &,
+                                       const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCheckNumericsInputsNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCheckNumericsOutputsNum, kernel_name_);
+  if (input_dtype_ == kNumberTypeFloat16) {
+    LaunchKernelFloat<float16>(inputs, outputs);
+  } else if (input_dtype_ == kNumberTypeFloat32) {
+    LaunchKernelFloat<float>(inputs, outputs);
+  } else if (input_dtype_ == kNumberTypeFloat64) {
+    LaunchKernelFloat<double>(inputs, outputs);
+  }
+  return true;
+}
+
+template <typename T>
+void CheckNumericsCpuKernelMod::CheckNanOrInf(T value) {
+  if (std::isnan(value)) {
+    MS_LOG(EXCEPTION) << ": Tensor had NaN values";
+  } else if (std::isinf(value)) {
+    MS_LOG(EXCEPTION) << ": Tensor had Inf values";
+  }
+}
+
+template <typename T>
+void CheckNumericsCpuKernelMod::LaunchKernelFloat(const std::vector<AddressPtr> &inputs,
+                                                  const std::vector<AddressPtr> &outputs) {
+  T *input = reinterpret_cast<T *>(inputs[0]->addr);
+  auto *output = reinterpret_cast<T *>(outputs[0]->addr);
+  size_t elem_num = inputs[0]->size / sizeof(T);
+
+  for (size_t i = 0; i < elem_num; i++) {
+    if constexpr (std::is_same_v<T, float16>) {
+      auto value = static_cast<float>(input[i]);
+      CheckNanOrInf(value);
+      output[i] = input[i];
+    } else {
+      auto value = input[i];
+      CheckNanOrInf(value);
+      output[i] = input[i];
+    }
+  }
+}
+
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CheckNumerics, CheckNumericsCpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/check_numerics_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/check_numerics_cpu_kernel.h
new file mode 100644
index 00000000000..53d4556bda1
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/check_numerics_cpu_kernel.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHECK_NUMERICS_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHECK_NUMERICS_CPU_KERNEL_H_
+
+#include <cmath>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class CheckNumericsCpuKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  CheckNumericsCpuKernelMod() = default;
+  ~CheckNumericsCpuKernelMod() override = default;
+
+  void InitKernel(const CNodePtr &kernelNode) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override {
+    static std::vector<KernelAttr> support_list = {
+      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
+    return support_list;
+  }
+
+ private:
+  template <typename T>
+  void LaunchKernelFloat(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  template <typename T>
+  void CheckNanOrInf(T value);
+
+  std::map<TypeId, size_t> dtype_map_ = {
+    {kNumberTypeFloat16, sizeof(float16)}, {kNumberTypeFloat32, sizeof(float)}, {kNumberTypeFloat64, sizeof(double)}};
+  TypeId input_dtype_{kTypeUnknown};
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHECK_NUMERICS_CPU_KERNEL_H_
diff --git a/mindspore/core/ops/check_numerics.cc b/mindspore/core/ops/check_numerics.cc
new file mode 100644
index 00000000000..70230f15423
--- /dev/null
+++ b/mindspore/core/ops/check_numerics.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/check_numerics.h"
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include "ops/op_utils.h"
+#include "mindapi/src/helper.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr CheckNumericsInferShape(const PrimitivePtr &primitive,
+                                           const std::vector<AbstractBasePtr> &input_args) {
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  return std::make_shared<abstract::Shape>(x_shape);
+}
+
+TypePtr CheckNumericsInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
+  auto x_dtype = input_args[0]->BuildType();
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, primitive->name());
+  return x_dtype;
+}
+}  // namespace
+AbstractBasePtr CheckNumericsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t kInputNum = 1;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
+  auto infer_type = CheckNumericsInferType(primitive, input_args);
+  auto infer_shape = CheckNumericsInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+MIND_API_BASE_IMPL(CheckNumerics, PrimitiveC, BaseOperator);
+REGISTER_PRIMITIVE_EVAL_IMPL(CheckNumerics, prim::kPrimCheckNumerics, CheckNumericsInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/check_numerics.h b/mindspore/core/ops/check_numerics.h
new file mode 100644
index 00000000000..b087469fc79
--- /dev/null
+++ b/mindspore/core/ops/check_numerics.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_CHECKNUMERICS_H_
+#define MINDSPORE_CORE_OPS_CHECKNUMERICS_H_
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameCheckNumerics = "CheckNumerics";
+
+class CheckNumerics : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(CheckNumerics);
+  CheckNumerics() : BaseOperator(kNameCheckNumerics) { InitIOName({"x"}, {"y"}); }
+};
+
+abstract::AbstractBasePtr CheckNumericsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                             const std::vector<abstract::AbstractBasePtr> &input_args);
+using PrimCheckNumericsPtr = std::shared_ptr<CheckNumerics>;
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_CHECKNUMERICS_H_
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index f09ef0966d1..51c43b08a74 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -108,6 +108,7 @@ constexpr auto kSegmentSum = "SegmentSum";
 constexpr auto kSegmentMin = "SegmentMin";
 constexpr auto kDynamicShape = "DynamicShape";
 constexpr auto kTensorShape = "TensorShape";
+constexpr auto kCheckNumerics = "CheckNumerics";
 constexpr auto kStack = "Stack";
 constexpr auto kUnstack = "Unstack";
 constexpr auto kTupleGetItem = "TupleGetItem";
@@ -327,6 +328,7 @@ GVAR_DEF(PrimitivePtr, kPrimStridedSlice, std::make_shared<Primitive>(kStridedSl
 GVAR_DEF(PrimitivePtr, kPrimStridedSliceGrad, std::make_shared<Primitive>(kStridedSliceGrad));
 GVAR_DEF(PrimitivePtr, kPrimTensorShape, std::make_shared<Primitive>(kTensorShape));
 GVAR_DEF(PrimitivePtr, kPrimDynamicShape, std::make_shared<Primitive>(kDynamicShape));
+GVAR_DEF(PrimitivePtr, kPrimCheckNumerics, std::make_shared<Primitive>(kCheckNumerics));
 GVAR_DEF(PrimitivePtr, kPrimEmbeddingLookup, std::make_shared<Primitive>("EmbeddingLookup"));
 GVAR_DEF(PrimitivePtr, kPrimEmbeddingLookupCommGrad, std::make_shared<Primitive>("EmbeddingLookupCommGrad"));
 GVAR_DEF(PrimitivePtr, kPrimSize, std::make_shared<Primitive>("Size"));
diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py
index 934322d7f58..f663b4bf52c 100644
--- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py
+++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py
@@ -26,6 +26,7 @@ from ..operations.array_ops import MatrixDiagV3
 from ..operations.array_ops import MatrixDiagPartV3
 from ..operations.array_ops import MatrixSetDiagV3
 from ..operations.array_ops import Triu
+from ..operations.array_ops import CheckNumerics
 from ..operations.array_ops import SegmentMax
 from ..operations.array_ops import SegmentMin
 from ..operations.array_ops import SegmentSum
@@ -227,6 +228,17 @@ def get_bprop_triu(self):
     return bprop


+@bprop_getters.register(CheckNumerics)
+def get_bprop_check_numerics(self):
+    """Generate bprop for CheckNumerics"""
+    check_numerics = CheckNumerics()
+
+    def bprop(x_input, out, dout):
+        return (check_numerics(dout),)
+
+    return bprop
+
+
 @bprop_getters.register(P.SplitV)
 def get_bprop_split_v(self):
     """Generate bprop for SplitV"""
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 134fd213ccf..b475977994c 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -152,6 +152,7 @@ from .environ_set import _environ_set_aicpu
 from .environ_get import _environ_get_aicpu
 from .environ_destroy_all import _environ_destroy_all_aicpu
 from .cross import _cross_aicpu
+from .check_numerics import _check_numerics_aicpu
 from .cummax import _cummax_aicpu
 from .round import _round_aicpu
 from .truncated_normal import _truncated_normal_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/check_numerics.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/check_numerics.py
new file mode 100644
index 00000000000..6204599f99f
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/check_numerics.py
@@ -0,0 +1,33 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================

+"""CheckNumerics op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+check_numerics_op_info = AiCPURegOp("CheckNumerics") \
+    .fusion_type("OPAQUE") \
+    .attr("message", "str") \
+    .input(0, "x", "required") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(check_numerics_op_info)
+def _check_numerics_aicpu():
+    """CheckNumerics AiCPU register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index d07837f4d0c..967c28e76a7 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -10,7 +10,6 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-
 # limitations under the License.
 # ============================================================================

@@ -250,6 +249,38 @@ class SameTypeShape(PrimitiveWithInfer):
         return x


+class CheckNumerics(Primitive):
+    """
+    Checks a tensor for NaN and Inf values.
+
+    Inputs:
+        - **x** (Tensor) - Input Tensor of any dimension. The data type is float16, float32 or float64.
+
+    Outputs:
+        Tensor, has the same shape and data type as `x` if `x` has no NaN or Inf values.
+
+    Raises:
+        TypeError: If the data type of `x` is not float16, float32 or float64.
+        RuntimeError: If `x` has NaN or Inf values.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([[1, 3], [2, 4]], dtype=np.float32))
+        >>> checknumerics = ops.CheckNumerics()
+        >>> output = checknumerics(x)
+        >>> print(output)
+        [[1. 3.]
+         [2. 4.]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init CheckNumerics"""
+        self.init_prim_io_names(inputs=['x'], outputs=['y'])
+
+
 class Cast(PrimitiveWithInfer):
     """
     Returns a tensor with the new specified data type.
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 4f7bd6b1ef5..32201cc5d31 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -37,6 +37,7 @@ from mindspore.ops.operations.math_ops import ReduceStd
 from mindspore.ops.operations.math_ops import Trace
 from mindspore.ops.operations import nn_ops as nps
 from mindspore.ops.operations.array_ops import Tril
+from mindspore.ops.operations.array_ops import CheckNumerics
 from mindspore.ops.operations.array_ops import SegmentMax
 from mindspore.ops.operations.array_ops import SegmentMin
 from mindspore.ops.operations.array_ops import SegmentSum
@@ -2759,6 +2760,10 @@ test_case_array_ops = [
         'block': P.Shape(),
         'desc_inputs': [[3, 3, 2, 2]],
         'skip': ['backward']}),
+    ('CheckNumerics', {
+        'block': CheckNumerics(),
+        'desc_inputs': [[1, 2, 3, 4]],
+        'skip': ['backward']}),
     ('Reshape', {
         'block': P.Reshape(),
         'desc_const': [(64,)],
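
For reference, a minimal usage sketch of the operator added by this patch (not part of the diff; it assumes a build that includes the registrations above and runs in PyNative mode on CPU):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops.operations.array_ops import CheckNumerics

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

check = CheckNumerics()

# Finite values pass through unchanged.
x = Tensor(np.array([[1, 3], [2, 4]], dtype=np.float32))
print(check(x))  # same values as x

# A NaN (or Inf) element makes the CPU kernel raise, which surfaces as RuntimeError.
bad = Tensor(np.array([1.0, float("nan")], dtype=np.float32))
try:
    check(bad)
except RuntimeError as err:
    print("CheckNumerics rejected the tensor:", err)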