forked from mindspore-Ecosystem/mindspore
switch implementation of BiasAdd op to nnacl
This commit is contained in:
parent
14cf33a6df
commit
8302820597
|
@ -16,28 +16,44 @@
|
|||
|
||||
#include "backend/kernel_compiler/cpu/bias_add_cpu_kernel.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t kBiasAddMinDim = 2;
|
||||
constexpr size_t kBiasAddMaxDim = 5;
|
||||
constexpr size_t kBiasAddInputNum = 2;
|
||||
|
||||
void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
data_shape_ = input_shape_.size();
|
||||
if (input_shape_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Input tensor's rank must be at least 2 for 'BiasAdd' Op, but input tensor's rank is "
|
||||
<< input_shape_.size();
|
||||
bias_param_.ndim_ = input_shape_.size();
|
||||
if (bias_param_.ndim_ < kBiasAddMinDim || bias_param_.ndim_ > kBiasAddMaxDim) {
|
||||
MS_LOG(EXCEPTION) << "Input tensor's rank must be in closed interval [2,5] for 'BiasAdd' Op,"
|
||||
"but input tensor's rank is "
|
||||
<< bias_param_.ndim_;
|
||||
}
|
||||
if (bias_shape_.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Bias's rank must be 1 for 'BiasAdd' Op, but bias' rank is" << bias_shape_.size();
|
||||
}
|
||||
if (input_shape_[1] != bias_shape_[0]) {
|
||||
MS_LOG(EXCEPTION) << "Bias shape not match, bias shape must be equal to C channel's shape";
|
||||
if (input_shape_[bias_param_.ndim_ - 1] != bias_shape_[0]) {
|
||||
MS_LOG(EXCEPTION) << "Bias shape [" << bias_shape_[0] << "] not match, it must equal C channel's shape:["
|
||||
<< input_shape_[bias_param_.ndim_ - 1] << "]";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < bias_param_.ndim_; ++i) {
|
||||
bias_param_.in_shape0_[i] = input_shape_[i];
|
||||
bias_param_.in_shape1_[i] = 1;
|
||||
bias_param_.out_shape_[i] = input_shape_[i];
|
||||
}
|
||||
|
||||
bias_param_.in_shape1_[bias_param_.ndim_ - 1] = input_shape_[bias_param_.ndim_ - 1];
|
||||
}
|
||||
|
||||
bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (inputs.size() != 2 || outputs.size() != 1) {
|
||||
if (inputs.size() != kBiasAddInputNum || outputs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "inputs outputs size not supoort";
|
||||
}
|
||||
|
||||
|
@ -45,46 +61,15 @@ bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
auto bias_addr = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
|
||||
if (input_shape_.size() > 2) {
|
||||
size_t hw_size = 1;
|
||||
for (size_t i = 2; i < input_shape_.size(); ++i) {
|
||||
hw_size *= input_shape_[i];
|
||||
}
|
||||
size_t data_num = std::accumulate(input_shape_.begin(), input_shape_.end(), 1LL, std::multiplies<int>());
|
||||
|
||||
size_t c_size = input_shape_[1];
|
||||
for (size_t n = 0; n < input_shape_[0]; ++n) {
|
||||
for (size_t c = 0; c < c_size; ++c) {
|
||||
size_t offset = n * c_size * hw_size + c * hw_size;
|
||||
size_t hw = 0;
|
||||
#ifdef ENABLE_AVX
|
||||
constexpr size_t C8NUM = 8;
|
||||
size_t hw8 = hw_size / C8NUM * C8NUM;
|
||||
const float *in_ptr = src_addr + offset;
|
||||
float *out_ptr = output_addr + offset;
|
||||
for (; hw < hw8; hw += C8NUM) {
|
||||
__m256 src_r1 = _mm256_loadu_ps(in_ptr);
|
||||
__m256 bias_r2 = _mm256_set1_ps(bias_addr[c]);
|
||||
__m256 dst_r3 = _mm256_add_ps(src_r1, bias_r2);
|
||||
_mm256_storeu_ps(out_ptr, dst_r3);
|
||||
std::vector<float> buffer_in(data_num, 0);
|
||||
std::vector<float> buffer_bias(data_num, 0);
|
||||
float *tile_in = &buffer_in.at(0);
|
||||
float *tile_bias = &buffer_bias.at(0);
|
||||
|
||||
in_ptr += C8NUM;
|
||||
out_ptr += C8NUM;
|
||||
}
|
||||
#endif
|
||||
for (; hw < hw_size; ++hw) {
|
||||
output_addr[offset + hw] = src_addr[offset + hw] + bias_addr[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size_t n_offset = 0;
|
||||
for (size_t n = 0; n < input_shape_[0]; ++n) {
|
||||
for (size_t c = 0; c < input_shape_[1]; ++c) {
|
||||
output_addr[n_offset + c] = src_addr[n_offset + c] + bias_addr[c];
|
||||
}
|
||||
n_offset += input_shape_[1];
|
||||
}
|
||||
}
|
||||
// BroadcastAdd always returns NNACL_OK, so no need to check return val.
|
||||
(void)BroadcastAdd(src_addr, bias_addr, tile_in, tile_bias, output_addr, data_num, &bias_param_);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -33,9 +34,9 @@ class BiasAddCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
size_t data_shape_{0};
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> bias_shape_;
|
||||
ArithmeticParameter bias_param_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BiasAdd, KernelAttr(), BiasAddCPUKernel);
|
||||
} // namespace kernel
|
||||
|
|
|
@ -20,7 +20,7 @@ bias_add_op_info = CpuRegOp("BiasAdd") \
|
|||
.input(0, "x", "required") \
|
||||
.input(1, "bias", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_ChannelLast, DataType.F32_Default, DataType.F32_ChannelLast) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -605,6 +605,7 @@ class DataType:
|
|||
BOOL_NHWC = ("bool", "NHWC")
|
||||
BOOL_HWCN = ("bool", "HWCN")
|
||||
BOOL_NDHWC = ("bool", "NDHWC")
|
||||
BOOL_ChannelLast = ("bool", "ChannelLast")
|
||||
|
||||
I8_None = ("int8", "")
|
||||
I8_Default = ("int8", "DefaultFormat")
|
||||
|
@ -616,6 +617,7 @@ class DataType:
|
|||
I8_NHWC = ("int8", "NHWC")
|
||||
I8_HWCN = ("int8", "HWCN")
|
||||
I8_NDHWC = ("int8", "NDHWC")
|
||||
I8_ChannelLast = ("int8", "ChannelLast")
|
||||
|
||||
U8_None = ("uint8", "")
|
||||
U8_Default = ("uint8", "DefaultFormat")
|
||||
|
@ -627,6 +629,7 @@ class DataType:
|
|||
U8_NHWC = ("uint8", "NHWC")
|
||||
U8_HWCN = ("uint8", "HWCN")
|
||||
U8_NDHWC = ("uint8", "NDHWC")
|
||||
U8_ChannelLast = ("uint8", "ChannelLast")
|
||||
|
||||
I16_None = ("int16", "")
|
||||
I16_Default = ("int16", "DefaultFormat")
|
||||
|
@ -638,6 +641,7 @@ class DataType:
|
|||
I16_NHWC = ("int16", "NHWC")
|
||||
I16_HWCN = ("int16", "HWCN")
|
||||
I16_NDHWC = ("int16", "NDHWC")
|
||||
I16_ChannelLast = ("int16", "ChannelLast")
|
||||
|
||||
U16_None = ("uint16", "")
|
||||
U16_Default = ("uint16", "DefaultFormat")
|
||||
|
@ -649,6 +653,7 @@ class DataType:
|
|||
U16_NHWC = ("uint16", "NHWC")
|
||||
U16_HWCN = ("uint16", "HWCN")
|
||||
U16_NDHWC = ("uint16", "NDHWC")
|
||||
U16_ChannelLast = ("uint16", "ChannelLast")
|
||||
|
||||
I32_None = ("int32", "")
|
||||
I32_Default = ("int32", "DefaultFormat")
|
||||
|
@ -660,6 +665,7 @@ class DataType:
|
|||
I32_NHWC = ("int32", "NHWC")
|
||||
I32_HWCN = ("int32", "HWCN")
|
||||
I32_NDHWC = ("int32", "NDHWC")
|
||||
I32_ChannelLast = ("int32", "ChannelLast")
|
||||
|
||||
U32_None = ("uint32", "")
|
||||
U32_Default = ("uint32", "DefaultFormat")
|
||||
|
@ -671,6 +677,7 @@ class DataType:
|
|||
U32_NHWC = ("uint32", "NHWC")
|
||||
U32_HWCN = ("uint32", "HWCN")
|
||||
U32_NDHWC = ("uint32", "NDHWC")
|
||||
U32_ChannelLast = ("uint32", "ChannelLast")
|
||||
|
||||
I64_None = ("int64", "")
|
||||
I64_Default = ("int64", "DefaultFormat")
|
||||
|
@ -682,6 +689,7 @@ class DataType:
|
|||
I64_NHWC = ("int64", "NHWC")
|
||||
I64_HWCN = ("int64", "HWCN")
|
||||
I64_NDHWC = ("int64", "NDHWC")
|
||||
I64_ChannelLast = ("int64", "ChannelLast")
|
||||
|
||||
U64_None = ("uint64", "")
|
||||
U64_Default = ("uint64", "DefaultFormat")
|
||||
|
@ -693,6 +701,7 @@ class DataType:
|
|||
U64_NHWC = ("uint64", "NHWC")
|
||||
U64_HWCN = ("uint64", "HWCN")
|
||||
U64_NDHWC = ("uint64", "NDHWC")
|
||||
U64_ChannelLast = ("uint64", "ChannelLast")
|
||||
|
||||
F16_None = ("float16", "")
|
||||
F16_Default = ("float16", "DefaultFormat")
|
||||
|
@ -709,6 +718,7 @@ class DataType:
|
|||
F16_NDC1HWC0 = ("float16", "NDC1HWC0")
|
||||
F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
|
||||
F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
|
||||
F16_ChannelLast = ("float16", "ChannelLast")
|
||||
|
||||
F32_None = ("float32", "")
|
||||
F32_Default = ("float32", "DefaultFormat")
|
||||
|
@ -725,6 +735,7 @@ class DataType:
|
|||
F32_NDC1HWC0 = ("float32", "NDC1HWC0")
|
||||
F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
|
||||
F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
|
||||
F32_ChannelLast = ("float32", "ChannelLast")
|
||||
|
||||
F64_None = ("float64", "")
|
||||
F64_Default = ("float64", "DefaultFormat")
|
||||
|
@ -736,3 +747,4 @@ class DataType:
|
|||
F64_NHWC = ("float64", "NHWC")
|
||||
F64_HWCN = ("float64", "HWCN")
|
||||
F64_NDHWC = ("float64", "NDHWC")
|
||||
F64_ChannelLast = ("float64", "ChannelLast")
|
||||
|
|
|
@ -36,11 +36,15 @@ class Net(nn.Cell):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add4d():
|
||||
x = np.ones([2, 3, 4, 4]).astype(np.float32)
|
||||
b = np.array([1, 1, 1]).astype(np.float32)
|
||||
x_shape = [2, 3, 4, 5]
|
||||
x = np.ones(x_shape).astype(np.float32)
|
||||
b = np.array([0.3, 0.5, 0.7]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 3, 4, 4]).astype(np.float32) * 2
|
||||
expect_output = x
|
||||
for i in range(x_shape[0]):
|
||||
for j in range(x_shape[1]):
|
||||
expect_output[i][j] = x[i][j] + b[j]
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
|
@ -49,11 +53,15 @@ def test_bias_add4d():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add2d():
|
||||
x = np.ones([2, 3]).astype(np.float32)
|
||||
b = np.array([1, 1, 1]).astype(np.float32)
|
||||
x_shape = [2, 3]
|
||||
x = np.ones(x_shape).astype(np.float32)
|
||||
b = np.array([0.3, 0.5, 0.7]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 3]).astype(np.float32) * 2
|
||||
expect_output = x
|
||||
for i in range(x_shape[0]):
|
||||
for j in range(x_shape[1]):
|
||||
expect_output[i][j] = x[i][j] + b[j]
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
|
@ -62,11 +70,15 @@ def test_bias_add2d():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add3d():
|
||||
x = np.ones([2, 3, 4]).astype(np.float32)
|
||||
b = np.array([1, 1, 1]).astype(np.float32)
|
||||
x_shape = [2, 3, 4]
|
||||
x = np.ones(x_shape).astype(np.float32)
|
||||
b = np.array([0.3, 0.5, 0.7]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 3, 4]).astype(np.float32) * 2
|
||||
expect_output = x
|
||||
for i in range(x_shape[0]):
|
||||
for j in range(x_shape[1]):
|
||||
expect_output[i][j] = x[i][j] + b[j]
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
|
@ -74,10 +86,50 @@ def test_bias_add3d():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add5d():
|
||||
x = np.ones([2, 5, 4, 4, 4]).astype(np.float32)
|
||||
b = np.array([1, 1, 1, 1, 1]).astype(np.float32)
|
||||
x_shape = [2, 5, 2, 3, 4]
|
||||
x = np.ones(x_shape).astype(np.float32)
|
||||
b = np.array([0.1, 0.3, 0.5, 0.7, 0.9]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 5, 4, 4, 4]).astype(np.float32) * 2
|
||||
expect_output = x
|
||||
for i in range(x_shape[0]):
|
||||
for j in range(x_shape[1]):
|
||||
expect_output[i][j] = x[i][j] + b[j]
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
|
||||
class Net2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.Div()
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, x, y, z, w):
|
||||
mul_ = self.mul(x, y)
|
||||
div_ = self.div(z, w)
|
||||
temp = self.bias_add(mul_, div_)
|
||||
temp = self.bias_add(temp, div_)
|
||||
return self.add(temp, x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_net2():
|
||||
x_shape = [2, 3, 4]
|
||||
x = np.ones(x_shape).astype(np.float32)
|
||||
y = np.ones(x_shape).astype(np.float32)
|
||||
z = np.array([1.1, 2.2, 3.4]).astype(np.float32)
|
||||
w = np.array([10, 10, 10]).astype(np.float32)
|
||||
net2 = Net2()
|
||||
output = net2(Tensor(x), Tensor(y), Tensor(z), Tensor(w))
|
||||
expect_out = (np.array([[[2.22, 2.22, 2.22, 2.22],
|
||||
[2.44, 2.44, 2.44, 2.44],
|
||||
[2.68, 2.68, 2.68, 2.68]],
|
||||
[[2.22, 2.22, 2.22, 2.22],
|
||||
[2.44, 2.44, 2.44, 2.44],
|
||||
[2.68, 2.68, 2.68, 2.68]]]))
|
||||
assert np.allclose(output.asnumpy(), expect_out)
|
||||
|
|
Loading…
Reference in New Issue