!16650 fix cpu transpose bug

From: @fan-jibin
Reviewed-by: @zhoufeng54,@zhaizhiqiang
Signed-off-by: @zhaizhiqiang
This commit is contained in:
mindspore-ci-bot 2021-05-21 15:10:22 +08:00 committed by Gitee
commit c6e6fcc97e
2 changed files with 7 additions and 9 deletions

View File

@ -35,7 +35,7 @@ void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
if (dtype_ == kTypeUnknown) {
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
}
if (axes_.size() > MAX_SHAPE_SIZE) {
if (axes_.size() > MAX_TRANSPOSE_DIM_SIZE) {
MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_SHAPE_SIZE << "D, but got " << axes_.size()
<< "D.";
}
@ -90,7 +90,7 @@ void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
output_shape[i] = SizeToInt(output_shape_[i]);
}
if (axes_.size() <= MAX_TRANSPOSE_DIM_SIZE) {
if (axes_.size() <= DIMENSION_6D) {
int res = NNACL_OK;
if constexpr (std::is_same_v<T, int8_t>) {
res = DoTransposeInt8(input_addr, output_addr, output_shape, &transpose_param_);
@ -113,10 +113,8 @@ void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
} else if constexpr (std::is_same_v<T, bool>) {
res = DoTransposeBool(input_addr, output_addr, output_shape, &transpose_param_);
}
if (res == NNACL_ERR) {
MS_LOG(EXCEPTION) << "Transpose input addr or output addr is null";
} else if (res == NNACL_PARAM_INVALID) {
MS_LOG(EXCEPTION) << "Transpose parameters are invalid.";
if (res != NNACL_OK) {
MS_LOG(ERROR) << "Transpose run failed";
}
} else {
size_t data_count = (inputs[0]->size) / sizeof(T);

View File

@ -1284,7 +1284,7 @@ def onp_kron(x, y):
return onp.kron(x, y)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@ -1428,7 +1428,7 @@ def onp_diff(input_array):
return a, b, c, d, e, f, g
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@ -1770,7 +1770,7 @@ def test_cov():
match_all_arrays(mnp_res, onp_res, error=1e-5)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training