Fix axis handling in the ArgMin/ArgMax grad bprop expander

This commit is contained in:
dabaiji 2022-11-17 11:04:43 +08:00
parent 8edb949d35
commit 752585d6c0
1 changed file with 6 additions and 7 deletions

View File

@ -533,15 +533,10 @@ NodePtr MinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const std::vect
NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int64_t &axis, const bool &keep_dims,
const NodePtr &out, const NodePtr &dout, const bool is_max) {
auto x_shape = ib->GetShape(x);
auto x_axis = axis;
auto x_axis = CheckRange(axis, SizeToLong(x_shape.size()));
auto onehot_axis = x_axis;
NodePtr dout_expand;
NodePtr new_out = out;
if (onehot_axis >= SizeToLong(x_shape.size())) {
onehot_axis = -1;
} else if (onehot_axis < -1) {
onehot_axis += SizeToLong(x_shape.size());
}
if (keep_dims) {
dout_expand = ib->TupleGetItem(dout, 1);
if (is_max) {
@ -552,11 +547,15 @@ NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int
} else {
dout_expand = ib->Emit("ExpandDims", {ib->TupleGetItem(dout, 1), ib->Value<int64_t>(onehot_axis)});
}
auto out_shape = ib->GetShape(ib->TupleGetItem(new_out, 0));
if (onehot_axis >= SizeToLong(out_shape.size())) {
onehot_axis = -1;
}
auto type_x = ib->GetDtype(x);
auto on_value = ib->Tensor(1, type_x);
auto off_value = ib->Tensor(0, type_x);
int64_t depth = x_shape[axis];
int64_t depth = x_shape[x_axis];
auto dx =
dout_expand * ib->Emit("OneHot", {ib->TupleGetItem(new_out, 0), ib->Value<int64_t>(depth), on_value, off_value},
{{"axis", MakeValue(onehot_axis)}});