fix minimumgrad and maximumgrad operator bug caused by student

This commit is contained in:
shen_jingxing 2022-12-14 17:22:18 +08:00
parent bdeeb90233
commit dcfdc5cdd8
2 changed files with 2 additions and 2 deletions

View File

@ -100,7 +100,7 @@ void MaximumGradRecTask(const T *x, const T *y, const T *dout, T *dx, T *dy, siz
size_t dout_i = i * dout_cargo[dim];
if (dim == dout_shape.size() - 1) {
if (*(x + x_index + x_i) >= *(y + y_index + y_i)) {
if (*(x + x_index + x_i) > *(y + y_index + y_i)) {
*(dx + x_index + x_i) += *(dout + dout_index + i);
} else {
*(dy + y_index + y_i) += *(dout + dout_index + i);

View File

@ -149,7 +149,7 @@ void MinimumGradRecTask(const T *x, const T *y, const T *dout, T *dx, T *dy, con
size_t dout_i = i * dout_cargo[dim];
if (dim == dout_shape.size() - 1) {
if (*(x + x_index + x_i) <= *(y + y_index + y_i)) {
if (*(x + x_index + x_i) < *(y + y_index + y_i)) {
*(dx + x_index + x_i) += *(dout + dout_index + i);
} else {
*(dy + y_index + y_i) += *(dout + dout_index + i);