add CPU l2loss op

This commit is contained in:
fanjibin 2021-07-20 21:44:39 +08:00 committed by fan-jibin
parent ed5fa7ba73
commit 9e5618a5b8
6 changed files with 278 additions and 1 deletions

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/l2loss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void L2LossCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (const size_t &d : x_shape) {
tensor_size_ *= d;
}
}
template <typename T>
bool L2LossCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto result_addr = reinterpret_cast<T *>(outputs[0]->addr);
*result_addr = (T)0;
for (size_t i = 0; i < tensor_size_; i++) {
*result_addr += input_addr[i] * input_addr[i];
}
*result_addr = *result_addr / 2;
return true;
}
template <typename T>
void L2LossCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but L2LossCPUKernel needs 1 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but L2LossCPUKernel needs 1 output.";
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_LOSS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_LOSS_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class L2LossCPUKernel : public CPUKernel {
public:
L2LossCPUKernel() = default;
~L2LossCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParam(const CNodePtr &kernel_node);
size_t tensor_size_{1};
};
MS_REG_CPU_KERNEL_T(L2Loss, KernelAttr(), L2LossCPUKernel, float16);
MS_REG_CPU_KERNEL_T(L2Loss, KernelAttr(), L2LossCPUKernel, float);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_LOSS_CPU_KERNEL_H_

View File

@ -64,3 +64,4 @@ from .one_hot import _one_hot_cpu
from .pad import _pad_cpu from .pad import _pad_cpu
from .range import _range_cpu from .range import _range_cpu
from .tensor_copy_slices import _tensor_copy_slices_cpu from .tensor_copy_slices import _tensor_copy_slices_cpu
from .l2loss import _l2loss_cpu

View File

@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""L2Loss op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
l2loss_op_info = CpuRegOp("L2Loss") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(l2loss_op_info)
def _l2loss_cpu():
"""L2Loss cpu register"""
return

View File

@ -2679,7 +2679,7 @@ class L2Loss(PrimitiveWithInfer):
TypeError: If dtype of `input_x` is neither float16 nor float32. TypeError: If dtype of `input_x` is neither float16 nor float32.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``Ascend`` ``GPU`` ``CPU``
Examples Examples
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16) >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)

View File

@ -0,0 +1,143 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
class L2LossNet(nn.Cell):
def __init__(self):
super(L2LossNet, self).__init__()
self.l2_loss = P.L2Loss()
def construct(self, x):
return self.l2_loss(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_pynative_fp32_2x2():
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
error = 1e-4
x = Tensor(np.array([[1., 2.], [3., 4.]]), ms.float32)
expect = np.array(15, np.float32)
output = P.L2Loss()(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_pynative_fp16_2x2():
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
error = 1e-4
x = Tensor(np.array([[1., 2.], [3., 4.]]), ms.float16)
expect = np.array(15, np.float16)
output = P.L2Loss()(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_pynative_fp32_1x4():
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
error = 1e-4
x = Tensor(np.array([1., 2., 3., 4.]), ms.float32)
expect = np.array(15, np.float32)
output = P.L2Loss()(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_pynative_fp16_1x4():
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
error = 1e-4
x = Tensor(np.array([1., 2., 3., 4.]), ms.float16)
expect = np.array(15, np.float16)
output = P.L2Loss()(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_graph_fp32_1x4():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
error = 1e-4
x = Tensor(np.array([1., 2., 3., 4.]), ms.float32)
expect = np.array(15, np.float32)
l2_loss = L2LossNet()
output = l2_loss(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_graph_fp16_1x4():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
error = 1e-4
x = Tensor(np.array([1., 2., 3., 4.]), ms.float16)
expect = np.array(15, np.float16)
l2_loss = L2LossNet()
output = l2_loss(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = C.GradOperation(get_all=True)
def construct(self, x):
gradient_function = self.grad_op(self.net)
return gradient_function(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_grad_fp32():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x = Tensor(np.array([2.4, 3.2, 1.2, 5.9, 9.]).astype(np.float32))
error = 1e-4
net = L2LossNet()
output = GradNet(net)(x)[0]
expect = x
diff = output.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_l2loss_grad_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
x = Tensor(np.array([[2.4, 3.2, 4.8], [1.2, 5.9, 9.]]).astype(np.float16))
error = 1e-4
net = L2LossNet()
output = GradNet(net)(x)[0]
expect = x
diff = output.asnumpy() - expect
assert np.all(diff < error)