!54124 Fix grad of indices for Gather.
Merge pull request !54124 from TronZhang/fix_indices_gather
This commit is contained in:
commit
2cbd5ae708
|
@ -422,6 +422,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
|||
auto batch_dims = ib->GetAttr<int64_t>(kAttrBatchDims);
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto ori_indices = indices; // indices may be changed latter.
|
||||
auto axis = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto x_shp = ib->GetShape(x);
|
||||
|
@ -471,7 +472,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
|||
x_grad = ib->UnsortedSegmentSum(values_transpose, indices, num_segment);
|
||||
}
|
||||
x_grad = ib->Transpose(x_grad, perm_2);
|
||||
return {x_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
|
||||
return {x_grad, ib->ZerosLike(ori_indices), ib->ZerosLike(axis)};
|
||||
}
|
||||
|
||||
if (ind_shp.empty()) {
|
||||
|
@ -493,7 +494,7 @@ NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
|||
}
|
||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_v, batch_dims);
|
||||
auto params_grad = ib->Transpose(x_grad, perm_2);
|
||||
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
|
||||
return {params_grad, ib->ZerosLike(ori_indices), ib->ZerosLike(axis)};
|
||||
}
|
||||
|
||||
ShapeArray ConcatOffsetCal(const ShapeArray &input_shapes, size_t axis_s) {
|
||||
|
|
Loading…
Reference in New Issue