From 64b4b980aae5d084f5a35d52950863bb0b2eccc5 Mon Sep 17 00:00:00 2001
From: z00512249
Date: Sun, 17 Apr 2022 14:29:47 +0800
Subject: [PATCH] add LpNorm implementation

---
 .../cpu/kernel/mkldnn/reduction_cpu_kernel.cc | 111 ++++++++++++++++++
 .../cpu/kernel/mkldnn/reduction_cpu_kernel.h  |  61 ++++++++++
 mindspore/core/ops/lp_norm.cc                 |   1 -
 mindspore/core/utils/convert_utils_base.h     |   2 +-
 .../mindspore/ops/operations/math_ops.py      |   5 +-
 tests/st/ops/cpu/test_lp_norm_op.py           |  67 +++++++++++
 6 files changed, 243 insertions(+), 4 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.h
 create mode 100644 tests/st/ops/cpu/test_lp_norm_op.py

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.cc
new file mode 100644
index 00000000000..2974ae6d6f9
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.cc
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.h"
+#include <algorithm>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include "utils/ms_utils.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+#include "mindspore/core/ops/lp_norm.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+struct ReductionDescParam {
+  dnnl::algorithm algorithm{dnnl::algorithm::undef};
+  float p_{2.0f};
+  float eps_{0.0f};
+};
+}  // namespace
+
+dnnl::reduction::desc ReductionCpuKernelMod::GetReductionDesc(const dnnl::memory::desc &src_desc,
+                                                              const dnnl::memory::desc &dst_desc) {
+  static const std::map<std::string, ReductionDescParam> reduction_op_desc_map{
+    {prim::kPrimLpNorm->name(), ReductionDescParam{dnnl::algorithm::reduction_norm_lp_sum, p_, eps_}}};
+  const auto desc_pair = reduction_op_desc_map.find(kernel_name_);
+  if (desc_pair == reduction_op_desc_map.end()) {
+    MS_LOG(EXCEPTION) << "ReductionCpuKernelMod does not support " << kernel_name_;
+  }
+  auto desc = CreateDesc<dnnl::reduction::desc>(desc_pair->second.algorithm, src_desc, dst_desc,
+                                                desc_pair->second.p_, desc_pair->second.eps_);
+  return desc;
+}
+
+void ReductionCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  const std::string p = "p";
+  if (!common::AnfAlgo::HasNodeAttr(p, kernel_node)) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the kernel attribute '" << p << "' is missing.";
+  }
+  p_ = LongToFloat(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, p));
+  const std::string eps = "epsilon";
+  if (!common::AnfAlgo::HasNodeAttr(eps, kernel_node)) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the kernel attribute '" << eps << "' is missing.";
+  }
+  eps_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, eps);
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
+  std::vector<size_t> input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
+  std::vector<size_t> output_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, kIndex0);
+  // The Reduction kernel requires a data shape of at least 4 dimensions, so extend the shapes to 4-D.
+  while (input_shape.size() < kIndex4) {
+    input_shape.insert(input_shape.begin(), 1);
+  }
+  while (output_shape.size() < kIndex4) {
+    output_shape.insert(output_shape.begin(), 1);
+  }
+  dnnl::memory::desc src_desc = GetDefaultMemDesc(input_shape);
+  dnnl::memory::desc dst_desc = GetDefaultMemDesc(output_shape);
+  auto desc = GetReductionDesc(src_desc, dst_desc);
+  auto prim_desc = CreateDesc<dnnl::reduction::primitive_desc>(desc, engine_);
+  primitive_ = CreatePrimitive<dnnl::reduction>(prim_desc);
+  AddArgument(DNNL_ARG_SRC, src_desc);
+  AddArgument(DNNL_ARG_DST, dst_desc);
+}
+
+template <typename T>
+bool ReductionCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &outputs) {
+  auto input = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
+  auto output = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
+  SetArgumentHandle(DNNL_ARG_SRC, input);
+  SetArgumentHandle(DNNL_ARG_DST, output);
+  ExecutePrimitive();
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, ReductionCpuKernelMod::ReductionFunc>> ReductionCpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   &ReductionCpuKernelMod::LaunchKernel<float>}};
+
+std::vector<KernelAttr> ReductionCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, ReductionFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LpNorm,
+                                 []() { return std::make_shared<ReductionCpuKernelMod>(prim::kPrimLpNorm->name()); });
+}  // namespace kernel
+}  // namespace mindspore
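A note on the algorithm choice above: oneDNN's `reduction_norm_lp_sum` computes, roughly, `dst = (sum(|src|^p) + eps)^(1/p)` over the reduced dimensions, which is why a single oneDNN reduction primitive can serve as the whole LpNorm kernel. A minimal NumPy sketch of that semantics, assuming the documented oneDNN behavior (the helper name is illustrative, not part of the patch):

```python
import numpy as np

def lp_norm_reference(x, axis, p=2.0, eps=0.0, keep_dims=False):
    """(sum(|x|^p) + eps) ** (1/p) over `axis` -- the reduction delegated to oneDNN."""
    acc = np.sum(np.power(np.abs(x), p), axis=tuple(axis), keepdims=keep_dims)
    return np.power(acc + eps, 1.0 / p)
```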
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.h
new file mode 100644
index 00000000000..35f83ecc7ce
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/reduction_cpu_kernel.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCTION_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCTION_CPU_KERNEL_H_
+#include <functional>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class ReductionCpuKernelMod : public MKLCpuKernelMod {
+ public:
+  ReductionCpuKernelMod() = default;
+  explicit ReductionCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
+  ~ReductionCpuKernelMod() override = default;
+
+  // To be deprecated API.
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  }
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using ReductionFunc = std::function<bool(ReductionCpuKernelMod *, const std::vector<AddressPtr> &,
+                                           const std::vector<AddressPtr> &)>;
+  dnnl::reduction::desc GetReductionDesc(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &dst_desc);
+
+  ReductionFunc kernel_func_;
+  float p_{2.0};
+  float eps_{1e-12};
+  static std::vector<std::pair<KernelAttr, ReductionFunc>> func_list_;
+  std::string kernel_type_{};
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCTION_CPU_KERNEL_H_
diff --git a/mindspore/core/ops/lp_norm.cc b/mindspore/core/ops/lp_norm.cc
index dfad9287883..1d3378be377 100644
--- a/mindspore/core/ops/lp_norm.cc
+++ b/mindspore/core/ops/lp_norm.cc
@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 
-#include
 #include
 #include
 #include
diff --git a/mindspore/core/utils/convert_utils_base.h b/mindspore/core/utils/convert_utils_base.h
index 5d2f88c024c..5a6aa9fe6cd 100644
--- a/mindspore/core/utils/convert_utils_base.h
+++ b/mindspore/core/utils/convert_utils_base.h
@@ -101,7 +101,7 @@ inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
 
 inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
 
-inline double LongToFloat(int64_t v) { return static_cast<float>(v); }
+inline float LongToFloat(int64_t v) { return static_cast<float>(v); }
 
 inline double FloatToDouble(float v) { return static_cast<double>(v); }
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index f214d8e3b51..974d66cd6d8 100644
--- a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -1320,7 +1320,7 @@ class LpNorm(Primitive):
         - **input** (Tensor) - Input tensor.
 
     Outputs:
-        Tensor, has the same dtype as `input`, which shape depends on the args axis.For example, if the size of input
+        Tensor, has the same dtype as `input`, which shape depends on the args axis. For example, if the size of input
         is (2, 3, 4), axis is [0, 1], Outputs' shape will be (4,).
 
     Raises:
@@ -1334,7 +1334,7 @@ class LpNorm(Primitive):
         ValueError: If the length of shape of `axis` is bigger than the length of shape of `input`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
@@ -1347,6 +1347,7 @@ class LpNorm(Primitive):
     @prim_attr_register
     def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
         """Initialize LpNorm"""
+        super().__init__("LpNorm")
         validator.check_value_type("p", p, [int], self.name)
         validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
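Since the docstring now lists CPU among the supported platforms, the op can be exercised directly; a quick usage sketch mirroring the docstring example (the device target and mode are set explicitly here for illustration):

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
input_x = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
net = P.LpNorm(axis=[0, 1], p=2, keep_dims=False)
print(net(input_x))  # expected to be close to [ 9.165152 10.954452]
```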
diff --git a/tests/st/ops/cpu/test_lp_norm_op.py b/tests/st/ops/cpu/test_lp_norm_op.py
new file mode 100644
index 00000000000..490a2bd7a0b
--- /dev/null
+++ b/tests/st/ops/cpu/test_lp_norm_op.py
@@ -0,0 +1,67 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+
+
+class LpNormNet(nn.Cell):
+    def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
+        super(LpNormNet, self).__init__()
+        self.lp_norm = P.LpNorm(axis, p, keep_dims, epsilon)
+
+    def construct(self, input_x):
+        output = self.lp_norm(input_x)
+        return output
+
+
+def lp_norm_np_benchmark(data_type):
+    """
+    Feature: generate a LpNorm numpy benchmark.
+    Description: Return the expected LpNorm result for the fixed test input.
+    Expectation: match the MindSpore LpNorm output.
+    """
+    result = np.array([9.165152, 10.954452]).astype(data_type)
+    return result
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_x86_cpu
+@pytest.mark.parametrize("data_type", [np.float32, np.float16])
+def test_lp_norm_op(data_type):
+    """
+    Feature: Test LpNorm.
+    Description: The op output needs to match the benchmark output in both graph and pynative mode.
+    Expectation: match the np benchmark.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    input_x = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(data_type)
+    error = 1e-6
+    if data_type == np.float16:
+        error = 1e-3
+    benchmark_output = lp_norm_np_benchmark(data_type)
+    axis = [0, 1]
+    p = 2
+    keep_dims = False
+    lp_norm = LpNormNet(axis, p, keep_dims)
+    output = lp_norm(Tensor(input_x))
+    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
+    context.set_context(mode=context.PYNATIVE_MODE)
+    output = lp_norm(Tensor(input_x))
+    np.testing.assert_allclose(output.asnumpy(), benchmark_output, rtol=error)
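The hard-coded values in `lp_norm_np_benchmark` can be cross-checked by hand: a 2-norm over axes (0, 1) collapses the 2x2x2 test input onto its last dimension. A quick NumPy verification (not part of the patch):

```python
import numpy as np

x = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], dtype=np.float32)
# Last-dim index 0 gathers 1, 3, 5, 7 -> sqrt(84); index 1 gathers 2, 4, 6, 8 -> sqrt(120).
print(np.sqrt(np.sum(np.square(x), axis=(0, 1))))  # ~[9.165152 10.954452], within the test's rtol
```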