fix argmin/max grad bprop expander issue
This commit is contained in:
parent
8edb949d35
commit
752585d6c0
|
@ -533,15 +533,10 @@ NodePtr MinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const std::vect
|
|||
NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int64_t &axis, const bool &keep_dims,
|
||||
const NodePtr &out, const NodePtr &dout, const bool is_max) {
|
||||
auto x_shape = ib->GetShape(x);
|
||||
auto x_axis = axis;
|
||||
auto x_axis = CheckRange(axis, SizeToLong(x_shape.size()));
|
||||
auto onehot_axis = x_axis;
|
||||
NodePtr dout_expand;
|
||||
NodePtr new_out = out;
|
||||
if (onehot_axis >= SizeToLong(x_shape.size())) {
|
||||
onehot_axis = -1;
|
||||
} else if (onehot_axis < -1) {
|
||||
onehot_axis += SizeToLong(x_shape.size());
|
||||
}
|
||||
if (keep_dims) {
|
||||
dout_expand = ib->TupleGetItem(dout, 1);
|
||||
if (is_max) {
|
||||
|
@ -552,11 +547,15 @@ NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int
|
|||
} else {
|
||||
dout_expand = ib->Emit("ExpandDims", {ib->TupleGetItem(dout, 1), ib->Value<int64_t>(onehot_axis)});
|
||||
}
|
||||
auto out_shape = ib->GetShape(ib->TupleGetItem(new_out, 0));
|
||||
if (onehot_axis >= SizeToLong(out_shape.size())) {
|
||||
onehot_axis = -1;
|
||||
}
|
||||
|
||||
auto type_x = ib->GetDtype(x);
|
||||
auto on_value = ib->Tensor(1, type_x);
|
||||
auto off_value = ib->Tensor(0, type_x);
|
||||
int64_t depth = x_shape[axis];
|
||||
int64_t depth = x_shape[x_axis];
|
||||
auto dx =
|
||||
dout_expand * ib->Emit("OneHot", {ib->TupleGetItem(new_out, 0), ib->Value<int64_t>(depth), on_value, off_value},
|
||||
{{"axis", MakeValue(onehot_axis)}});
|
||||
|
|
Loading…
Reference in New Issue