develop int32 AddN op for cpu

This commit is contained in:
buxue 2021-05-18 17:51:44 +08:00
parent b82956cefc
commit b5ec25cc49
3 changed files with 75 additions and 42 deletions

View File

@ -17,14 +17,25 @@
#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
#include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
#include "utils/ms_utils.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "Add failed.";
}
}
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
CheckParam(kernel_node);
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
@ -42,15 +53,31 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
for (size_t index = 2; index < input_num_; ++index) {
SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
if (dtype_ == kNumberTypeFloat32) {
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
for (size_t index = 2; index < input_num_; ++index) {
SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
}
} else if (dtype_ == kNumberTypeInt32) {
size_t elements_num = outputs[0]->size / sizeof(int);
const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
auto output = reinterpret_cast<int *>(outputs[0]->addr);
auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task_0, elements_num);
for (size_t index = 2; index < input_num_; ++index) {
const auto input = reinterpret_cast<int *>(inputs[index]->addr);
auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task, elements_num);
}
} else {
MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
}
return true;
}

View File

@ -24,7 +24,7 @@ namespace mindspore {
namespace kernel {
class AddNCPUKernel : public MKLCPUKernel {
public:
AddNCPUKernel() : input_num_(0) {}
AddNCPUKernel() = default;
~AddNCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@ -34,13 +34,17 @@ class AddNCPUKernel : public MKLCPUKernel {
private:
void CheckParam(const CNodePtr &kernel_node);
size_t input_num_;
size_t input_num_{0};
std::vector<size_t> output_shape_;
TypeId dtype_{kNumberTypeFloat32};
};
MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AddNCPUKernel);
MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AddNCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -19,60 +19,62 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Net2I(nn.Cell):
class Net2Inputs(nn.Cell):
def __init__(self):
super(Net2I, self).__init__()
super(Net2Inputs, self).__init__()
self.addn = P.AddN()
def construct(self, x, y):
return self.addn((x, y))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_2Input():
x = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.float32)
y = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.float32)
addn = Net2I()
output = addn(Tensor(x, mstype.float32), Tensor(y, mstype.float32))
print("output:\n", output)
expect_result = [[[0., 2.],
[4., 6.],
[8., 10.]],
[[12., 14.],
[16., 18.],
[20., 22.]]]
def test_two_tensors_add():
x = np.arange(2 * 3 * 2).reshape((2, 3, 2))
y = np.arange(88, 2 * 3 * 2 + 88).reshape((2, 3, 2))
addn_net = Net2Inputs()
dtypes = (np.int32, np.float32)
for dtype in dtypes:
output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)))
expect_result = (x + y).astype(dtype)
assert output.asnumpy().dtype == expect_result.dtype
assert np.array_equal(output.asnumpy(), expect_result)
assert (output.asnumpy() == expect_result).all()
class Net3I(nn.Cell):
class Net4Inputs(nn.Cell):
def __init__(self):
super(Net3I, self).__init__()
super(Net4Inputs, self).__init__()
self.addn = P.AddN()
def construct(self, x, y, z):
return self.addn((x, y, z))
def construct(self, x, y, m, n):
return self.addn((x, y, m, n))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_3Input():
x = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
y = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
z = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
addn = Net3I()
output = addn(Tensor(x, mstype.float32), Tensor(y, mstype.float32), Tensor(z, mstype.float32))
print("output:\n", output)
expect_result = [[0., 3., 6.],
[9., 12., 15]]
def test_four_tensors_add():
x = np.arange(2 * 3).reshape((2, 3))
y = np.arange(1, 2 * 3 + 1).reshape((2, 3))
m = np.arange(2, 2 * 3 + 2).reshape((2, 3))
n = np.arange(3, 2 * 3 + 3).reshape((2, 3))
addn_net = Net4Inputs()
dtypes = (np.int32, np.float32)
for dtype in dtypes:
output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)),
Tensor(m.astype(dtype)), Tensor(n.astype(dtype)))
expect_result = (x + y + m + n).astype(dtype)
assert output.asnumpy().dtype == expect_result.dtype
assert np.array_equal(output.asnumpy(), expect_result)
assert (output.asnumpy() == expect_result).all()
if __name__ == '__main__':
test_net_2Input()
test_net_3Input()
test_two_tensors_add()
test_four_tensors_add()