diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/dynamic_broadcast_grad_args_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/dynamic_broadcast_grad_args_cpu_kernel.cc index 9ce89b8eccb..2e1a0499b58 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/dynamic_broadcast_grad_args_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/dynamic_broadcast_grad_args_cpu_kernel.cc @@ -107,32 +107,22 @@ bool DynamicBroadcastGradientArgsCpuKernelMod::LaunchKernel(const std::vector ranks = {input_size_list_[0] / sizeof(T), input_size_list_[1] / sizeof(T)}; std::vector> grad_reduce_idx(kDynamicBroadcastGradientArgsInputsNum); - bool all_equal = true; size_t max_rank = ranks[0] > ranks[1] ? ranks[0] : ranks[1]; - size_t min_rank = ranks[0] < ranks[1] ? ranks[0] : ranks[1]; - for (size_t i = 0; i < min_rank; i++) { - if (s0_addr[i] != s1_addr[i]) { - all_equal = false; - break; - } + std::vector> reverse_shapes(kDynamicBroadcastGradientArgsInputsNum); + for (size_t j = 0; j < ranks[0]; j++) { + reverse_shapes[0].push_back(s0_addr[ranks[0] - j - 1]); + } + if (reverse_shapes[0].size() < max_rank) { + reverse_shapes[0].resize(max_rank, 1); + } + for (size_t j = 0; j < ranks[1]; j++) { + reverse_shapes[1].push_back(s1_addr[ranks[1] - j - 1]); + } + if (reverse_shapes[1].size() < max_rank) { + reverse_shapes[1].resize(max_rank, 1); } - if (!all_equal) { - // Reverse shapes - std::vector> reverse_shapes(kDynamicBroadcastGradientArgsInputsNum); - for (size_t j = 0; j < ranks[0]; j++) { - reverse_shapes[0].push_back(s0_addr[ranks[0] - j - 1]); - } - if (reverse_shapes[0].size() < max_rank) { - reverse_shapes[0].resize(max_rank, 1); - } - - for (size_t j = 0; j < ranks[1]; j++) { - reverse_shapes[1].push_back(s1_addr[ranks[1] - j - 1]); - } - if (reverse_shapes[1].size() < max_rank) { - reverse_shapes[1].resize(max_rank, 1); - } + if (reverse_shapes[0] != reverse_shapes[1]) { grad_reduce_idx = GetGradIndex(reverse_shapes, max_rank); }