add l2normalize ops for cpu
parent 394a3fe379
commit 33870dc46d

@@ -0,0 +1,154 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
template <typename T>
void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  epsilon_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
  axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
  CheckParam(kernel_node);
  if (axis_ < 0) {
    axis_ += SizeToInt(input_shape_.size());
  }
}

template <typename T>
void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims,
                                              std::unique_ptr<T[]> *denominator_addr) {
  T epsilon = (T)epsilon_;
  // Calculate transpose axes and stride: the reduced axis is moved last so each
  // denominator is accumulated over a contiguous run of `stride` transposed positions.
  size_t stride = 1;
  std::vector<size_t> axes(input_shape_.size());
  int k = 0;
  for (int i = 0; i < dims; ++i) {
    if (i != axis_) {
      axes[k] = i;
      ++k;
    } else {
      stride *= input_shape_[i];
    }
  }
  axes[k] = axis_;

  std::vector<size_t> transpose_shape(input_shape_.size());
  for (int i = 0; i < dims; ++i) {
    transpose_shape[i] = input_shape_[axes[i]];
  }

  TransposeIterator tran_base_iter(std::move(transpose_shape), std::move(axes), input_shape_);

  auto task = [&](size_t start, size_t end) {
    auto iter = tran_base_iter;
    iter.SetPos(start * stride);
    for (size_t i = start; i < end; ++i) {
      // Accumulate the squared sum in locals so parallel chunks do not share state.
      T denominator = input_addr[iter.GetPos()];
      denominator = denominator * denominator;
      iter.GenNextPos();
      for (size_t j = 1; j < stride; ++j) {
        T temp = input_addr[iter.GetPos()];
        denominator += temp * temp;
        iter.GenNextPos();
      }
      denominator = (denominator > epsilon) ? denominator : epsilon;
      (*denominator_addr)[i] = sqrt(denominator);
    }
  };
  CPUKernelUtils::ParallelFor(task, reduce_size);
}

template <typename T>
void L2NormalizeCPUKernel<T>::CalcOutput(const T *input_addr, const std::vector<size_t> reduce_shape,
                                         const size_t output_size, T *output_addr,
                                         std::unique_ptr<T[]> const &denominator_addr) {
  BroadcastIterator broad_base_iter(input_shape_, reduce_shape, output_shape_);
  auto task = [&](size_t start, size_t end) {
    auto iter = broad_base_iter;
    iter.SetPos(start);
    for (size_t i = start; i < end; ++i) {
      T dividend = input_addr[iter.GetInputPosA()];
      T divisor = denominator_addr[iter.GetInputPosB()];
      // Advance the iterator once per element, regardless of which branch is taken below.
      iter.GenNextPos();
      if (divisor == (T)0) {
        if (dividend == (T)0) {
          output_addr[i] = std::numeric_limits<T>::quiet_NaN();
          continue;
        }
        if (std::numeric_limits<T>::has_infinity) {
          output_addr[i] = dividend > (T)0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
        } else {
          output_addr[i] = dividend > (T)0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
        }
        continue;
      }
      output_addr[i] = dividend / divisor;
    }
  };
  CPUKernelUtils::ParallelFor(task, output_size);
}

template <typename T>
bool L2NormalizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> & /*workspace*/,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);

  int dims = SizeToInt(input_shape_.size());
  std::vector<size_t> reduce_shape = input_shape_;
  size_t reduce_size = 1;
  reduce_shape[axis_] = 1;
  for (int i = 0; i < dims; ++i) {
    reduce_size *= reduce_shape[i];
  }
  auto denominator_addr = std::make_unique<T[]>(reduce_size);

  L2NormalizeCPUKernel<T>::CalcDenominator(input_addr, reduce_size, dims, &denominator_addr);

  size_t output_size = outputs[0]->size / sizeof(T);
  L2NormalizeCPUKernel<T>::CalcOutput(input_addr, reduce_shape, output_size, output_addr, denominator_addr);

  return true;
}

template <typename T>
void L2NormalizeCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  int dims = SizeToInt(input_shape_.size());
  if (input_num != 1) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but L2NormalizeCPUKernel needs 1 input.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 1) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but L2NormalizeCPUKernel needs 1 output.";
  }
  if (axis_ < -dims || axis_ >= dims) {
    MS_LOG(EXCEPTION) << "Attr axis " << axis_ << " must be in range [" << -dims << ", " << dims << ").";
  }
  if (epsilon_ == 0.0) {
    MS_LOG(EXCEPTION) << "Attr epsilon can not be zero.";
  }
}
}  // namespace kernel
}  // namespace mindspore
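
For reference, a minimal NumPy sketch of what this kernel computes (the function name is illustrative and not part of the commit): square-sum the input along `axis`, clamp the sum from below by `epsilon`, take the square root, and divide the input by the broadcast result.

import numpy as np

def l2_normalize_reference(x, axis=0, epsilon=1e-4):
    # Squared sum along the normalization axis, kept for broadcasting.
    square_sum = np.sum(x * x, axis=axis, keepdims=True)
    # Clamp by epsilon before the square root, mirroring CalcDenominator.
    denominator = np.sqrt(np.maximum(square_sum, epsilon))
    return x / denominator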

@@ -0,0 +1,61 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_NORMALIZE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_NORMALIZE_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include <limits>
#include <utility>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class L2NormalizeCPUKernel : public CPUKernel {
 public:
  L2NormalizeCPUKernel() = default;
  ~L2NormalizeCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  void CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims,
                       std::unique_ptr<T[]> *denominator_addr);

  void CalcOutput(const T *input_addr, const std::vector<size_t> reduce_shape, const size_t output_size,
                  T *output_addr, std::unique_ptr<T[]> const &denominator_addr);

 private:
  std::vector<size_t> input_shape_;
  std::vector<size_t> output_shape_;
  float epsilon_;
  int axis_;
  void CheckParam(const CNodePtr &kernel_node);
};

MS_REG_CPU_KERNEL_T(L2Normalize, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                    L2NormalizeCPUKernel, float16);

MS_REG_CPU_KERNEL_T(L2Normalize, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                    L2NormalizeCPUKernel, float);
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_NORMALIZE_CPU_KERNEL_H_
@@ -3036,7 +3036,7 @@ class L2Normalize(PrimitiveWithInfer):
         TypeError: If dtype of `input_x` is neither float16 nor float32.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> l2_normalize = ops.L2Normalize()
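
With this change the operator is documented as dispatchable on CPU. A minimal usage sketch under that assumption (values are illustrative; the tests added below exercise the same path):

import numpy as np
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x = Tensor(np.arange(24).astype(np.float32).reshape(2, 3, 4))
output = P.L2Normalize(axis=0, epsilon=1e-4)(x)
print(output.shape)  # (2, 3, 4)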

@@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import dtype as mstype
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Net(Cell):
    def __init__(self, axis=0, epsilon=1e-4):
        super(Net, self).__init__()
        self.norm = P.L2Normalize(axis=axis, epsilon=epsilon)

    def construct(self, x):
        return self.norm(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_l2normalize_float32():
    x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4)
    expect = x / np.sqrt(np.sum(x**2, axis=0, keepdims=True))
    x = Tensor(x)
    error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5

    norm_op = Net(axis=0)
    output = norm_op(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(-diff < error)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_l2normalize_float16():
    x = np.arange(96).astype(np.float16).reshape(2, 3, 4, 4)
    expect = x / np.sqrt(np.sum(x**2, axis=0, keepdims=True))
    x = Tensor(x, dtype=mstype.float16)
    error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-3

    norm_op = Net(axis=0)
    output = norm_op(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(-diff < error)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_l2normalize_axis():
    axis = -2
    x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4)
    expect = x / np.sqrt(np.sum(x**2, axis=axis, keepdims=True))
    x = Tensor(x)
    error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5

    norm_op = Net(axis=axis)
    output = norm_op(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(-diff < error)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_l2normalize_epsilon():
    axis = -1
    epsilon = 900000.0
    x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4)
    expect = x / np.sqrt(epsilon)
    x = Tensor(x)
    error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5

    norm_op = Net(axis=axis, epsilon=epsilon)
    output = norm_op(x)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(-diff < error)
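
A quick check of the epsilon test above (standalone NumPy, not part of the commit): with epsilon = 900000.0 the squared sums along the last axis never exceed 92**2 + 93**2 + 94**2 + 95**2 = 34974, so the clamped denominator is sqrt(epsilon) everywhere and the expected output reduces to x / np.sqrt(epsilon).

import numpy as np
x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4)
max_square_sum = np.sum(x ** 2, axis=-1, keepdims=True).max()
assert max_square_sum == 34974.0
assert max_square_sum < 900000.0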