!49742 Fix the bug that CumulativeLogsumexp does not support dynamic shape

Merge pull request !49742 from zong_shuai/cumulativelogexp
i-robot 2023-03-06 03:03:59 +00:00 committed by Gitee
commit 3e1290e2de
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 4 additions and 30 deletions
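
For context, "dynamic" here means the network is compiled with a partially unknown input shape; the old bprop baked the static shape of dout into the graph. A rough, hypothetical reproduction script is sketched below. It is not part of this PR, and the exact CumulativeLogsumexp call signature, the GRAPH_MODE/set_inputs usage, and kernel availability on a given backend are all assumptions.

# Hypothetical repro sketch (not from this PR). Assumes the primitive is called as
# op(x, axis) with a 0-D int32 axis tensor and that the chosen backend has a kernel.
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindspore.ops.operations.math_ops import CumulativeLogsumexp

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = CumulativeLogsumexp(exclusive=False, reverse=False)
        self.axis = Tensor(0, ms.int32)

    def construct(self, x):
        return self.op(x, self.axis)

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
# Declare a dynamic first dimension; before this fix the backward pass of the op
# reportedly could not be built for such inputs.
net.set_inputs(Tensor(shape=[None], dtype=ms.float32))
grad_fn = ms.grad(net)
x = Tensor(np.random.randn(5).astype(np.float32))
print(grad_fn(x))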

@@ -1693,23 +1693,15 @@ REG_BPROP_BUILDER("CumulativeLogsumexp").SetBody(BODYFUNC(ib) {
   auto dout = ib->GetInput(kIndex3);
   bool reverse = GetValue<bool>(ib->GetAttr("reverse"));
   NodePtr dtype_min = nullptr;
-  auto fp64_flag = false;
   if ((ib->GetDtype(x))->type_id() == TypeId::kNumberTypeFloat16) {
     dtype_min = ib->Tensor(-65500e+0, kFloat16);
   } else {
     if ((ib->GetDtype(x))->type_id() == TypeId::kNumberTypeFloat32) {
       dtype_min = ib->Tensor(-3.4028235e+38, kFloat32);
     } else {
-      if ((ib->GetDtype(x))->type_id() == TypeId::kNumberTypeFloat64) {
-        dout = ib->Cast(dout, kFloat32);
-        x = ib->Cast(x, kFloat32);
-        out = ib->Cast(out, kFloat32);
-        dtype_min = ib->Tensor(-3.4028235e+38, kFloat32);
-        fp64_flag = true;
-      }
+      dtype_min = ib->Tensor(-1.7976931348623157e+308, kFloat64);
     }
   }
-  dtype_min = ib->Emit("BroadcastTo", {dtype_min}, {{"shape", MakeValue(ib->GetShape(dout))}});
   auto log_grad_positive = ib->Select(ib->Greater(dout, ib->Tensor(0, ib->GetDtype(dout))), ib->Log(dout), dtype_min);
   auto log_grad_negative =
     ib->Select(ib->Less(dout, ib->Tensor(0, ib->GetDtype(dout))), ib->Log(ib->Neg(dout)), dtype_min);
@@ -1721,11 +1713,6 @@ REG_BPROP_BUILDER("CumulativeLogsumexp").SetBody(BODYFUNC(ib) {
   auto output_neg =
     ib->Exp(ib->Add((ib->Emit("CumulativeLogsumexp", {ib->Sub(log_grad_negative, out), axis},
                               {{"exclusive", ib->GetAttr("exclusive")}, {"reverse", MakeValue(!reverse)}})),
             x));
-  if (fp64_flag) {
-    output_pos = ib->Cast(output_pos, kFloat64);
-    output_neg = ib->Cast(output_neg, kFloat64);
-    x = ib->Cast(x, kFloat64);
-  }
   return {ib->Sub(output_pos, output_neg), ib->ZerosLike(x)};
 });
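
For reference, the backward rule being rebuilt above is: for y = CumulativeLogsumexp(x) along an axis (exclusive=False, reverse=False), dx_i = exp(x_i) * sum over k >= i of dout_k * exp(-y_k). The bprop evaluates it by splitting dout by sign, working in the log domain with dtype_min standing in for log(0), and cumulating with the opposite reverse flag. A minimal NumPy sketch checking that reformulation against the direct sum (illustrative only, not MindSpore code; assumes the default exclusive/reverse attributes):

# NumPy-only check of the split-by-sign, log-domain gradient used by the bprop.
import numpy as np

def cum_logsumexp(a, reverse=False):
    # naive cumulative log-sum-exp over a 1-D array
    a = a[::-1] if reverse else a
    out = np.log(np.cumsum(np.exp(a)))
    return out[::-1] if reverse else out

def bprop_split(x, out, dout, dtype_min=np.finfo(np.float64).min):
    # dtype_min plays the role of log(0) for the masked-out entries
    log_pos = np.where(dout > 0, np.log(np.where(dout > 0, dout, 1.0)), dtype_min)
    log_neg = np.where(dout < 0, np.log(np.where(dout < 0, -dout, 1.0)), dtype_min)
    pos = np.exp(cum_logsumexp(log_pos - out, reverse=True) + x)
    neg = np.exp(cum_logsumexp(log_neg - out, reverse=True) + x)
    return pos - neg

rng = np.random.default_rng(0)
x = rng.normal(size=6)
out = cum_logsumexp(x)
dout = rng.normal(size=6)
# direct formula: dx_i = exp(x_i) * sum over k >= i of dout_k * exp(-out_k)
direct = np.exp(x) * np.cumsum((dout * np.exp(-out))[::-1])[::-1]
with np.errstate(divide="ignore"):  # log(0) -> -inf is expected for fully masked suffixes
    print(np.allclose(direct, bprop_split(x, out, dout)))  # True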

@@ -25,7 +25,7 @@ from mindspore.ops.functional import broadcast_gradient_args
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
-from mindspore import nn, ops, Tensor
+from mindspore import nn, Tensor
 from mindspore.ops.operations.math_ops import Real, Imag, Complex, Angle
 from mindspore.ops.operations.math_ops import Polar
 from mindspore.ops.operations.math_ops import ComplexAbs
@@ -468,19 +468,6 @@ def get_brop_cumulative_logsumexp(self):
     less_op = P.Less()
     neg_op = P.Neg()

-    def where_v2(condition, x=None, y=None):
-        return_all = None
-        if x is None and y is None:
-            return_all = mnp.where(condition, x, y)
-        elif x is not None and y is not None:
-            shape_ = x.shape
-            input_y = np.resize(y, shape_)
-            input_y = Tensor(input_y).astype(x.dtype)
-            return_all = ops.select(condition, x, input_y)
-        else:
-            raise ValueError("x and y must both be non-None or both be None.")
-        return return_all
-
     def bprop(x, axis, out, dout):
         dtype_min = 0
         if x.dtype == mstype.float16:
@@ -489,8 +476,8 @@ def get_brop_cumulative_logsumexp(self):
             dtype_min = -3.4028235e+38
         elif x.dtype == mstype.float64:
             dtype_min = -1.7976931348623157e+308
-        log_grad_positive = where_v2(greater_op(dout, 0), log_op(dout), dtype_min)
-        log_grad_negative = where_v2(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
+        log_grad_positive = mnp.where(greater_op(dout, 0), log_op(dout), dtype_min)
+        log_grad_negative = mnp.where(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
         output_pos = exp_op(cumulative_op(log_grad_positive - out, axis) + x)
         output_neg = exp_op(cumulative_op(log_grad_negative - out, axis) + x)
         return (output_pos - output_neg, zeros_like(x))
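
On the Python side the change mirrors the C++ one: the removed where_v2 helper materialized the scalar fill value with np.resize(y, x.shape), which needs a concrete shape at graph-build time and so breaks once the shape is dynamic, while mnp.where simply broadcasts the scalar. A small NumPy-only illustration of that difference (values and names are illustrative):

# Plain NumPy illustration: resizing a scalar requires a known static shape,
# broadcasting it inside where does not.
import numpy as np

dtype_min = -3.4028235e+38
cond = np.array([[True, False, False], [False, True, True]])
branch = np.array([[0.5, -1.0, 0.0], [-2.0, 3.0, 1.5]])  # stand-in for log_op(dout)

# Old where_v2 path: the fill value is expanded to the (static) target shape first.
old = np.where(cond, branch, np.resize(dtype_min, branch.shape))
# New path: the scalar broadcasts inside where, so no shape is needed up front.
new = np.where(cond, branch, dtype_min)
print(np.array_equal(old, new))  # True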