fixbug in DynamicBroadcastGradientArgsCpuKernel

This commit is contained in:
dayschan 2022-09-19 19:47:21 +08:00
parent d2e29f0818
commit 16b7f245eb
1 changed files with 13 additions and 23 deletions

View File

@ -107,32 +107,22 @@ bool DynamicBroadcastGradientArgsCpuKernelMod::LaunchKernel(const std::vector<ke
std::vector<size_t> ranks = {input_size_list_[0] / sizeof(T), input_size_list_[1] / sizeof(T)};
std::vector<std::vector<T>> grad_reduce_idx(kDynamicBroadcastGradientArgsInputsNum);
bool all_equal = true;
size_t max_rank = ranks[0] > ranks[1] ? ranks[0] : ranks[1];
size_t min_rank = ranks[0] < ranks[1] ? ranks[0] : ranks[1];
for (size_t i = 0; i < min_rank; i++) {
if (s0_addr[i] != s1_addr[i]) {
all_equal = false;
break;
}
std::vector<std::vector<T>> reverse_shapes(kDynamicBroadcastGradientArgsInputsNum);
for (size_t j = 0; j < ranks[0]; j++) {
reverse_shapes[0].push_back(s0_addr[ranks[0] - j - 1]);
}
if (reverse_shapes[0].size() < max_rank) {
reverse_shapes[0].resize(max_rank, 1);
}
for (size_t j = 0; j < ranks[1]; j++) {
reverse_shapes[1].push_back(s1_addr[ranks[1] - j - 1]);
}
if (reverse_shapes[1].size() < max_rank) {
reverse_shapes[1].resize(max_rank, 1);
}
if (!all_equal) {
// Reverse shapes
std::vector<std::vector<T>> reverse_shapes(kDynamicBroadcastGradientArgsInputsNum);
for (size_t j = 0; j < ranks[0]; j++) {
reverse_shapes[0].push_back(s0_addr[ranks[0] - j - 1]);
}
if (reverse_shapes[0].size() < max_rank) {
reverse_shapes[0].resize(max_rank, 1);
}
for (size_t j = 0; j < ranks[1]; j++) {
reverse_shapes[1].push_back(s1_addr[ranks[1] - j - 1]);
}
if (reverse_shapes[1].size() < max_rank) {
reverse_shapes[1].resize(max_rank, 1);
}
if (reverse_shapes[0] != reverse_shapes[1]) {
grad_reduce_idx = GetGradIndex(reverse_shapes, max_rank);
}