forked from mindspore-Ecosystem/mindspore
fixbug in DynamicBroadcastGradientArgsCpuKernel
This commit is contained in:
parent
d2e29f0818
commit
16b7f245eb
|
@ -107,32 +107,22 @@ bool DynamicBroadcastGradientArgsCpuKernelMod::LaunchKernel(const std::vector<ke
|
|||
std::vector<size_t> ranks = {input_size_list_[0] / sizeof(T), input_size_list_[1] / sizeof(T)};
|
||||
|
||||
std::vector<std::vector<T>> grad_reduce_idx(kDynamicBroadcastGradientArgsInputsNum);
|
||||
bool all_equal = true;
|
||||
size_t max_rank = ranks[0] > ranks[1] ? ranks[0] : ranks[1];
|
||||
size_t min_rank = ranks[0] < ranks[1] ? ranks[0] : ranks[1];
|
||||
for (size_t i = 0; i < min_rank; i++) {
|
||||
if (s0_addr[i] != s1_addr[i]) {
|
||||
all_equal = false;
|
||||
break;
|
||||
}
|
||||
std::vector<std::vector<T>> reverse_shapes(kDynamicBroadcastGradientArgsInputsNum);
|
||||
for (size_t j = 0; j < ranks[0]; j++) {
|
||||
reverse_shapes[0].push_back(s0_addr[ranks[0] - j - 1]);
|
||||
}
|
||||
if (reverse_shapes[0].size() < max_rank) {
|
||||
reverse_shapes[0].resize(max_rank, 1);
|
||||
}
|
||||
for (size_t j = 0; j < ranks[1]; j++) {
|
||||
reverse_shapes[1].push_back(s1_addr[ranks[1] - j - 1]);
|
||||
}
|
||||
if (reverse_shapes[1].size() < max_rank) {
|
||||
reverse_shapes[1].resize(max_rank, 1);
|
||||
}
|
||||
if (!all_equal) {
|
||||
// Reverse shapes
|
||||
std::vector<std::vector<T>> reverse_shapes(kDynamicBroadcastGradientArgsInputsNum);
|
||||
for (size_t j = 0; j < ranks[0]; j++) {
|
||||
reverse_shapes[0].push_back(s0_addr[ranks[0] - j - 1]);
|
||||
}
|
||||
if (reverse_shapes[0].size() < max_rank) {
|
||||
reverse_shapes[0].resize(max_rank, 1);
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < ranks[1]; j++) {
|
||||
reverse_shapes[1].push_back(s1_addr[ranks[1] - j - 1]);
|
||||
}
|
||||
if (reverse_shapes[1].size() < max_rank) {
|
||||
reverse_shapes[1].resize(max_rank, 1);
|
||||
}
|
||||
|
||||
if (reverse_shapes[0] != reverse_shapes[1]) {
|
||||
grad_reduce_idx = GetGradIndex(reverse_shapes, max_rank);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue