!16650 fix cpu transpose bug
From: @fan-jibin
Reviewed-by: @zhoufeng54, @zhaizhiqiang
Signed-off-by: @zhaizhiqiang
This commit is contained in:
commit
c6e6fcc97e
|
@ -35,7 +35,7 @@ void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
if (dtype_ == kTypeUnknown) {
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
if (axes_.size() > MAX_SHAPE_SIZE) {
|
||||
if (axes_.size() > MAX_TRANSPOSE_DIM_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_SHAPE_SIZE << "D, but got " << axes_.size()
|
||||
<< "D.";
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
output_shape[i] = SizeToInt(output_shape_[i]);
|
||||
}
|
||||
|
||||
if (axes_.size() <= MAX_TRANSPOSE_DIM_SIZE) {
|
||||
if (axes_.size() <= DIMENSION_6D) {
|
||||
int res = NNACL_OK;
|
||||
if constexpr (std::is_same_v<T, int8_t>) {
|
||||
res = DoTransposeInt8(input_addr, output_addr, output_shape, &transpose_param_);
|
||||
|
@ -113,10 +113,8 @@ void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
} else if constexpr (std::is_same_v<T, bool>) {
|
||||
res = DoTransposeBool(input_addr, output_addr, output_shape, &transpose_param_);
|
||||
}
|
||||
if (res == NNACL_ERR) {
|
||||
MS_LOG(EXCEPTION) << "Transpose input addr or output addr is null";
|
||||
} else if (res == NNACL_PARAM_INVALID) {
|
||||
MS_LOG(EXCEPTION) << "Transpose parameters are invalid.";
|
||||
if (res != NNACL_OK) {
|
||||
MS_LOG(ERROR) << "Transpose run failed";
|
||||
}
|
||||
} else {
|
||||
size_t data_count = (inputs[0]->size) / sizeof(T);
|
||||
|
|
|
@ -1284,7 +1284,7 @@ def onp_kron(x, y):
|
|||
return onp.kron(x, y)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -1428,7 +1428,7 @@ def onp_diff(input_array):
|
|||
return a, b, c, d, e, f, g
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -1770,7 +1770,7 @@ def test_cov():
|
|||
match_all_arrays(mnp_res, onp_res, error=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
Loading…
Reference in New Issue