!54124 Fix grad of indices for Gather.

Merge pull request !54124 from TronZhang/fix_indices_gather
This commit is contained in:
i-robot 2023-05-15 03:30:42 +00:00 committed by Gitee
commit 2cbd5ae708
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 2 deletions

View File

@ -422,6 +422,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
auto batch_dims = ib->GetAttr<int64_t>(kAttrBatchDims);
auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto ori_indices = indices; // indices may be changed latter.
auto axis = ib->GetInput(kIndex2);
auto dout = ib->GetInput(kIndex4);
auto x_shp = ib->GetShape(x);
@ -471,7 +472,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
x_grad = ib->UnsortedSegmentSum(values_transpose, indices, num_segment);
}
x_grad = ib->Transpose(x_grad, perm_2);
return {x_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
return {x_grad, ib->ZerosLike(ori_indices), ib->ZerosLike(axis)};
}
if (ind_shp.empty()) {
@ -493,7 +494,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
}
auto perm_2 = GenerateInverseIndex(x_shp, axis_v, batch_dims);
auto params_grad = ib->Transpose(x_grad, perm_2);
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
return {params_grad, ib->ZerosLike(ori_indices), ib->ZerosLike(axis)};
}
ShapeArray ConcatOffsetCal(const ShapeArray &input_shapes, size_t axis_s) {