forked from mindspore-Ecosystem/mindspore
develop int32 AddN op for cpu
This commit is contained in:
parent
b82956cefc
commit
b5ec25cc49
|
@ -17,14 +17,25 @@
|
|||
#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
|
||||
int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
|
||||
if (ret != NNACL_OK) {
|
||||
MS_LOG(EXCEPTION) << "Add failed.";
|
||||
}
|
||||
}
|
||||
|
||||
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CheckParam(kernel_node);
|
||||
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
|
@ -42,15 +53,31 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
|
||||
ExecutePrimitive();
|
||||
for (size_t index = 2; index < input_num_; ++index) {
|
||||
SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
|
||||
if (dtype_ == kNumberTypeFloat32) {
|
||||
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
|
||||
ExecutePrimitive();
|
||||
for (size_t index = 2; index < input_num_; ++index) {
|
||||
SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
|
||||
ExecutePrimitive();
|
||||
}
|
||||
} else if (dtype_ == kNumberTypeInt32) {
|
||||
size_t elements_num = outputs[0]->size / sizeof(int);
|
||||
const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
|
||||
const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto output = reinterpret_cast<int *>(outputs[0]->addr);
|
||||
auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
|
||||
CPUKernelUtils::ParallelFor(task_0, elements_num);
|
||||
for (size_t index = 2; index < input_num_; ++index) {
|
||||
const auto input = reinterpret_cast<int *>(inputs[index]->addr);
|
||||
auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
|
||||
CPUKernelUtils::ParallelFor(task, elements_num);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
class AddNCPUKernel : public MKLCPUKernel {
|
||||
public:
|
||||
AddNCPUKernel() : input_num_(0) {}
|
||||
AddNCPUKernel() = default;
|
||||
~AddNCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
@ -34,13 +34,17 @@ class AddNCPUKernel : public MKLCPUKernel {
|
|||
|
||||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
size_t input_num_;
|
||||
size_t input_num_{0};
|
||||
std::vector<size_t> output_shape_;
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(AddN,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
AddNCPUKernel);
|
||||
MS_REG_CPU_KERNEL(AddN,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
AddNCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,60 +19,62 @@ import pytest
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
class Net2I(nn.Cell):
|
||||
|
||||
class Net2Inputs(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2I, self).__init__()
|
||||
super(Net2Inputs, self).__init__()
|
||||
self.addn = P.AddN()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.addn((x, y))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_2Input():
|
||||
x = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.float32)
|
||||
y = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.float32)
|
||||
addn = Net2I()
|
||||
output = addn(Tensor(x, mstype.float32), Tensor(y, mstype.float32))
|
||||
print("output:\n", output)
|
||||
expect_result = [[[0., 2.],
|
||||
[4., 6.],
|
||||
[8., 10.]],
|
||||
[[12., 14.],
|
||||
[16., 18.],
|
||||
[20., 22.]]]
|
||||
def test_two_tensors_add():
|
||||
x = np.arange(2 * 3 * 2).reshape((2, 3, 2))
|
||||
y = np.arange(88, 2 * 3 * 2 + 88).reshape((2, 3, 2))
|
||||
addn_net = Net2Inputs()
|
||||
dtypes = (np.int32, np.float32)
|
||||
for dtype in dtypes:
|
||||
output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)))
|
||||
expect_result = (x + y).astype(dtype)
|
||||
assert output.asnumpy().dtype == expect_result.dtype
|
||||
assert np.array_equal(output.asnumpy(), expect_result)
|
||||
|
||||
assert (output.asnumpy() == expect_result).all()
|
||||
|
||||
class Net3I(nn.Cell):
|
||||
class Net4Inputs(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net3I, self).__init__()
|
||||
super(Net4Inputs, self).__init__()
|
||||
self.addn = P.AddN()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
return self.addn((x, y, z))
|
||||
def construct(self, x, y, m, n):
|
||||
return self.addn((x, y, m, n))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_3Input():
|
||||
x = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
|
||||
y = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
|
||||
z = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
|
||||
addn = Net3I()
|
||||
output = addn(Tensor(x, mstype.float32), Tensor(y, mstype.float32), Tensor(z, mstype.float32))
|
||||
print("output:\n", output)
|
||||
expect_result = [[0., 3., 6.],
|
||||
[9., 12., 15]]
|
||||
def test_four_tensors_add():
|
||||
x = np.arange(2 * 3).reshape((2, 3))
|
||||
y = np.arange(1, 2 * 3 + 1).reshape((2, 3))
|
||||
m = np.arange(2, 2 * 3 + 2).reshape((2, 3))
|
||||
n = np.arange(3, 2 * 3 + 3).reshape((2, 3))
|
||||
addn_net = Net4Inputs()
|
||||
dtypes = (np.int32, np.float32)
|
||||
for dtype in dtypes:
|
||||
output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)),
|
||||
Tensor(m.astype(dtype)), Tensor(n.astype(dtype)))
|
||||
expect_result = (x + y + m + n).astype(dtype)
|
||||
assert output.asnumpy().dtype == expect_result.dtype
|
||||
assert np.array_equal(output.asnumpy(), expect_result)
|
||||
|
||||
assert (output.asnumpy() == expect_result).all()
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_net_2Input()
|
||||
test_net_3Input()
|
||||
test_two_tensors_add()
|
||||
test_four_tensors_add()
|
||||
|
|
Loading…
Reference in New Issue