diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc new file mode 100644 index 00000000000..66d1976f197 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc @@ -0,0 +1,154 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +template +void L2NormalizeCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + epsilon_ = AnfAlgo::GetNodeAttr(kernel_node, "epsilon"); + axis_ = LongToInt(AnfAlgo::GetNodeAttr(kernel_node, "axis")); + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CheckParam(kernel_node); + if (axis_ < 0) { + axis_ += SizeToInt(input_shape_.size()); + } +} + +template +void L2NormalizeCPUKernel::CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims, + std::unique_ptr *denominator_addr) { + T temp = (T)0.0; + T epsilon = (T)epsilon_; + T denominator = (T)0.0; + // Calculate transpose axes and stride + size_t stride = 1; + std::vector axes(input_shape_.size()); + int k = 0; + for (int i = 0; i < dims; ++i) { + if (i != axis_) { + axes[k] = i; + ++k; + } else { + stride *= input_shape_[i]; + } + } + axes[k] = axis_; + + std::vector transpose_shape(input_shape_.size()); + for (int i = 0; i < dims; ++i) { + transpose_shape[i] = input_shape_[axes[i]]; + } + + TransposeIterator tran_base_iter(std::move(transpose_shape), std::move(axes), input_shape_); + + auto task = [&](size_t start, size_t end) { + auto iter = tran_base_iter; + iter.SetPos(start * stride); + for (size_t i = start; i < end; ++i) { + denominator = input_addr[iter.GetPos()]; + denominator = denominator * denominator; + iter.GenNextPos(); + for (size_t j = 1; j < stride; ++j) { + temp = input_addr[iter.GetPos()]; + denominator += temp * temp; + iter.GenNextPos(); + } + denominator = (denominator > epsilon) ? denominator : epsilon; + (*denominator_addr)[i] = sqrt(denominator); + } + }; + CPUKernelUtils::ParallelFor(task, reduce_size); +} + +template +void L2NormalizeCPUKernel::CalcOutput(const T *input_addr, const std::vector reduce_shape, + const size_t output_size, T *output_addr, + std::unique_ptr const &denominator_addr) { + BroadcastIterator broad_base_iter(input_shape_, reduce_shape, output_shape_); + auto task = [&](size_t start, size_t end) { + auto iter = broad_base_iter; + iter.SetPos(start); + for (size_t i = start; i < end; ++i) { + T dividend = input_addr[iter.GetInputPosA()]; + T divisor = denominator_addr[iter.GetInputPosB()]; + if (divisor == (T)0) { + if (dividend == (T)0) { + output_addr[i] = std::numeric_limits::quiet_NaN(); + continue; + } + if (std::numeric_limits::has_infinity) { + output_addr[i] = dividend > (T)0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + } else { + output_addr[i] = dividend > (T)0 ? std::numeric_limits::max() : std::numeric_limits::min(); + } + continue; + } + output_addr[i] = dividend / divisor; + iter.GenNextPos(); + } + }; + CPUKernelUtils::ParallelFor(task, output_size); +} + +template +bool L2NormalizeCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + int dims = input_shape_.size(); + std::vector reduce_shape = input_shape_; + size_t reduce_size = 1; + reduce_shape[axis_] = 1; + for (int i = 0; i < dims; ++i) { + reduce_size *= reduce_shape[i]; + } + auto denominator_addr = std::make_unique(reduce_size); + + L2NormalizeCPUKernel::CalcDenominator(input_addr, reduce_size, dims, &denominator_addr); + + size_t output_size = outputs[0]->size / sizeof(T); + L2NormalizeCPUKernel::CalcOutput(input_addr, reduce_shape, output_size, output_addr, denominator_addr); + + return true; +} + +template +void L2NormalizeCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + int dims = SizeToInt(input_shape_.size()); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but L2NormalizeCPUKernel needs 1 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but L2NormalizeCPUKernel needs 1 output."; + } + if (axis_ < -dims || axis_ >= dims) { + MS_LOG(EXCEPTION) << "Attr axis_ " << axis_ << " must be in " << -dims << "~" << dims; + } + if (epsilon_ == 0.0) { + MS_LOG(EXCEPTION) << "Attr epsilon can not be zero."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.h new file mode 100644 index 00000000000..9e158416646 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.h @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_NORMALIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_NORMALIZE_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class L2NormalizeCPUKernel : public CPUKernel { + public: + L2NormalizeCPUKernel() = default; + ~L2NormalizeCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + void CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims, + std::unique_ptr *denominator_addr); + + void CalcOutput(const T *input_addr, const std::vector reduce_shape, const size_t output_size, T *output_addr, + std::unique_ptr const &denominator_addr); + + private: + std::vector input_shape_; + std::vector output_shape_; + float epsilon_; + int axis_; + void CheckParam(const CNodePtr &kernel_node); +}; + +MS_REG_CPU_KERNEL_T(L2Normalize, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + L2NormalizeCPUKernel, float16); + +MS_REG_CPU_KERNEL_T(L2Normalize, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + L2NormalizeCPUKernel, float); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_NORMALIZE_CPU_KERNEL_H_ diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 8334a634e48..f59d74f3e4b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3036,7 +3036,7 @@ class L2Normalize(PrimitiveWithInfer): TypeError: If dtype of `input_x` is neither float16 nor float32. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> l2_normalize = ops.L2Normalize() diff --git a/tests/st/ops/cpu/test_l2normalize_op.py b/tests/st/ops/cpu/test_l2normalize_op.py new file mode 100644 index 00000000000..53a2325f7cc --- /dev/null +++ b/tests/st/ops/cpu/test_l2normalize_op.py @@ -0,0 +1,100 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore import dtype as mstype +from mindspore.nn import Cell +from mindspore.ops import operations as P +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Net(Cell): + def __init__(self, axis=0, epsilon=1e-4): + super(Net, self).__init__() + self.norm = P.L2Normalize(axis=axis, epsilon=epsilon) + + def construct(self, x): + return self.norm(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_l2normalize_float32(): + x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4) + expect = x / np.sqrt(np.sum(x**2, axis=0, keepdims=True)) + x = Tensor(x) + error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5 + + norm_op = Net(axis=0) + output = norm_op(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_l2normalize_float16(): + x = np.arange(96).astype(np.float16).reshape(2, 3, 4, 4) + expect = x / np.sqrt(np.sum(x**2, axis=0, keepdims=True)) + x = Tensor(x, dtype=mstype.float16) + error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-3 + + norm_op = Net(axis=0) + output = norm_op(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_l2normalize_axis(): + axis = -2 + x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4) + expect = x / np.sqrt(np.sum(x**2, axis=axis, keepdims=True)) + x = Tensor(x) + error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5 + + norm_op = Net(axis=axis) + output = norm_op(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_l2normalize_epsilon(): + axis = -1 + epsilon = 900000.0 + x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4) + expect = x / np.sqrt(epsilon) + x = Tensor(x) + error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5 + + norm_op = Net(axis=axis, epsilon=epsilon) + output = norm_op(x) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error)