forked from mindspore-Ecosystem/mindspore
!12922 biasadd report error when input 5d with nchw format at cpu
From: @wangyanling10 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
8b733dccaa
|
@ -3,6 +3,20 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
|
|||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
|
||||
if(ENABLE_CPU)
|
||||
if("${X86_64_SIMD}" STREQUAL "sse")
|
||||
add_compile_definitions(ENABLE_SSE)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.2")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2")
|
||||
endif()
|
||||
if("${X86_64_SIMD}" STREQUAL "avx")
|
||||
add_compile_definitions(ENABLE_SSE)
|
||||
add_compile_definitions(ENABLE_AVX)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(ENABLE_ACL)
|
||||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
include_directories(${ASCEND_PATH}/acllib/include)
|
||||
|
@ -29,7 +43,7 @@ if(ENABLE_GPU)
|
|||
find_package(CUDA REQUIRED)
|
||||
find_package(Threads)
|
||||
if(${CUDA_VERSION} VERSION_LESS ${MS_REQUIRE_CUDA_VERSION})
|
||||
message(FATAL_ERROR "The minimum CUDA version ${MS_REQUIRE_CUDA_VERSION} is required, \
|
||||
message(FATAL_ERROR "The minimum CUDA version ${MS_REQUIRE_CUDA_VERSION} is required, \
|
||||
but only CUDA ${CUDA_VERSION} found.")
|
||||
endif()
|
||||
enable_language(CUDA)
|
||||
|
|
|
@ -22,21 +22,16 @@ void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
if (input_shape_.size() == 4) {
|
||||
data_shape_ = 4;
|
||||
} else if (input_shape_.size() == 2) {
|
||||
data_shape_ = 2;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "bias add input data format should be NCHW or NC";
|
||||
}
|
||||
if (input_shape_.size() != 2 && input_shape_.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "bias add input shape nchw or nc";
|
||||
data_shape_ = input_shape_.size();
|
||||
if (input_shape_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Input tensor's rank must be at least 2 for 'BiasAdd' Op, but input tensor's rank is "
|
||||
<< input_shape_.size();
|
||||
}
|
||||
if (bias_shape_.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "bias shape invalid";
|
||||
MS_LOG(EXCEPTION) << "Bias's rank must be 1 for 'BiasAdd' Op, but bias' rank is" << bias_shape_.size();
|
||||
}
|
||||
if (input_shape_[1] != bias_shape_[0]) {
|
||||
MS_LOG(EXCEPTION) << "bias shape not match";
|
||||
MS_LOG(EXCEPTION) << "Bias shape not match, bias shape must be equal to C channel's shape";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,22 +45,36 @@ bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
auto bias_addr = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
|
||||
if (data_shape_ == 4) {
|
||||
size_t h_size = input_shape_[3];
|
||||
size_t c_size = input_shape_[2] * h_size;
|
||||
size_t n_size = input_shape_[1] * c_size;
|
||||
size_t hw_size = input_shape_[2] * input_shape_[3];
|
||||
size_t n_offset = 0;
|
||||
if (input_shape_.size() > 2) {
|
||||
size_t hw_size = 1;
|
||||
for (size_t i = 2; i < input_shape_.size(); ++i) {
|
||||
hw_size *= input_shape_[i];
|
||||
}
|
||||
|
||||
size_t c_size = input_shape_[1];
|
||||
for (size_t n = 0; n < input_shape_[0]; ++n) {
|
||||
size_t c_offset = 0;
|
||||
for (size_t c = 0; c < input_shape_[1]; ++c) {
|
||||
for (size_t hw = 0; hw < hw_size; ++hw) {
|
||||
size_t offset = n_offset + c_offset + hw;
|
||||
output_addr[offset] = src_addr[offset] + bias_addr[c];
|
||||
for (size_t c = 0; c < c_size; ++c) {
|
||||
size_t offset = n * c_size * hw_size + c * hw_size;
|
||||
size_t hw = 0;
|
||||
#ifdef ENABLE_AVX
|
||||
constexpr size_t C8NUM = 8;
|
||||
size_t hw8 = hw_size / C8NUM * C8NUM;
|
||||
const float *in_ptr = src_addr + offset;
|
||||
float *out_ptr = output_addr + offset;
|
||||
for (; hw < hw8; hw += C8NUM) {
|
||||
__m256 src_r1 = _mm256_loadu_ps(in_ptr);
|
||||
__m256 bias_r2 = _mm256_set1_ps(bias_addr[c]);
|
||||
__m256 dst_r3 = _mm256_add_ps(src_r1, bias_r2);
|
||||
_mm256_storeu_ps(out_ptr, dst_r3);
|
||||
|
||||
in_ptr += C8NUM;
|
||||
out_ptr += C8NUM;
|
||||
}
|
||||
#endif
|
||||
for (; hw < hw_size; ++hw) {
|
||||
output_addr[offset + hw] = src_addr[offset + hw] + bias_addr[c];
|
||||
}
|
||||
c_offset += c_size;
|
||||
}
|
||||
n_offset += n_size;
|
||||
}
|
||||
} else {
|
||||
size_t n_offset = 0;
|
||||
|
|
|
@ -33,7 +33,7 @@ class BiasAddCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
uint8_t data_shape_{0};
|
||||
size_t data_shape_{0};
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> bias_shape_;
|
||||
};
|
||||
|
|
|
@ -21,8 +21,9 @@ namespace kernel {
|
|||
void BiasAddGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (input_shape_.size() != 4 && input_shape_.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "input data format not support";
|
||||
if (input_shape_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Input tensor's rank must be at least 2 for 'BiasAddGrad' Op, but input tensor's rank is "
|
||||
<< input_shape_.size();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,23 +35,21 @@ bool BiasAddGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const s
|
|||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
|
||||
if (input_shape_.size() == 4) {
|
||||
size_t h_size = input_shape_[3];
|
||||
size_t c_size = h_size * input_shape_[2];
|
||||
size_t n_size = c_size * input_shape_[1];
|
||||
size_t hw_size = input_shape_[2] * input_shape_[3];
|
||||
size_t c_offset = 0;
|
||||
for (size_t c = 0; c < input_shape_[1]; ++c) {
|
||||
if (input_shape_.size() > 2) {
|
||||
size_t hw_size = 1;
|
||||
for (size_t i = 2; i < input_shape_.size(); ++i) {
|
||||
hw_size *= input_shape_[i];
|
||||
}
|
||||
|
||||
size_t c_size = input_shape_[1];
|
||||
for (size_t c = 0; c < c_size; ++c) {
|
||||
output_addr[c] = 0;
|
||||
size_t n_offset = 0;
|
||||
for (size_t n = 0; n < input_shape_[0]; ++n) {
|
||||
size_t offset = n * c_size * hw_size + c * hw_size;
|
||||
for (size_t hw = 0; hw < hw_size; ++hw) {
|
||||
size_t offset = c_offset + n_offset + hw;
|
||||
output_addr[c] += input_addr[offset];
|
||||
output_addr[c] += input_addr[offset + hw];
|
||||
}
|
||||
n_offset += n_size;
|
||||
}
|
||||
c_offset += c_size;
|
||||
}
|
||||
} else if (input_shape_.size() == 2) {
|
||||
for (size_t c = 0; c < input_shape_[1]; ++c) {
|
||||
|
|
|
@ -35,7 +35,7 @@ class Net(nn.Cell):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add1():
|
||||
def test_bias_add4d():
|
||||
x = np.ones([2, 3, 4, 4]).astype(np.float32)
|
||||
b = np.array([1, 1, 1]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
|
@ -48,7 +48,7 @@ def test_bias_add1():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add2():
|
||||
def test_bias_add2d():
|
||||
x = np.ones([2, 3]).astype(np.float32)
|
||||
b = np.array([1, 1, 1]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
|
@ -56,3 +56,52 @@ def test_bias_add2():
|
|||
expect_output = np.ones([2, 3]).astype(np.float32) * 2
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add3d():
|
||||
x = np.ones([2, 3, 4]).astype(np.float32)
|
||||
b = np.array([1, 1, 1]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 3, 4]).astype(np.float32) * 2
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add5d():
|
||||
x = np.ones([2, 5, 4, 4, 4]).astype(np.float32)
|
||||
b = np.array([1, 1, 1, 1, 1]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 5, 4, 4, 4]).astype(np.float32) * 2
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add6d():
|
||||
x = np.ones([2, 4, 4, 4, 4, 1]).astype(np.float32)
|
||||
b = np.array([1, 1, 1, 1]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 4, 4, 4, 4, 1]).astype(np.float32) * 2
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add7d():
|
||||
x = np.ones([2, 4, 4, 4, 4, 1, 2]).astype(np.float32)
|
||||
b = np.array([1, 1, 1, 1]).astype(np.float32)
|
||||
bias_add = Net()
|
||||
output = bias_add(Tensor(x), Tensor(b))
|
||||
expect_output = np.ones([2, 4, 4, 4, 4, 1, 2]).astype(np.float32) * 2
|
||||
print(output)
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit"
|
||||
|
|
|
@ -35,7 +35,7 @@ class Net(nn.Cell):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add_grad1():
|
||||
def test_bias_add_grad2d():
|
||||
dout = np.ones([2, 3]).astype(np.float32)
|
||||
bias_add_grad = Net()
|
||||
output = bias_add_grad(Tensor(dout))
|
||||
|
@ -47,10 +47,32 @@ def test_bias_add_grad1():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add_grad2():
|
||||
def test_bias_add_grad4d():
|
||||
dout = np.ones([2, 3, 4, 4]).astype(np.float32)
|
||||
bias_add_grad = Net()
|
||||
output = bias_add_grad(Tensor(dout))
|
||||
expect_output = np.array([32., 32., 32.]).astype(np.float32)
|
||||
print(output.asnumpy())
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add_grad execute failed, please check current code commit"
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add_grad5d():
|
||||
dout = np.ones([2, 3, 4, 4, 2]).astype(np.float32)
|
||||
bias_add_grad = Net()
|
||||
output = bias_add_grad(Tensor(dout))
|
||||
expect_output = np.array([64., 64., 64.]).astype(np.float32)
|
||||
print(output.asnumpy())
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add_grad execute failed, please check current code commit"
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bias_add_grad7d():
|
||||
dout = np.ones([2, 3, 4, 4, 2, 1, 10]).astype(np.float32)
|
||||
bias_add_grad = Net()
|
||||
output = bias_add_grad(Tensor(dout))
|
||||
expect_output = np.array([640., 640., 640.]).astype(np.float32)
|
||||
print(output.asnumpy())
|
||||
assert np.all(output.asnumpy() == expect_output), "bias_add_grad execute failed, please check current code commit"
|
||||
|
|
Loading…
Reference in New Issue