forked from mindspore-Ecosystem/mindspore
!49742 solve the bug that CumulativeLogsumexp does not support dynamic shape
Merge pull request !49742 from zong_shuai/cumulativelogexp
commit 3e1290e2de
@@ -1693,23 +1693,15 @@ REG_BPROP_BUILDER("CumulativeLogsumexp").SetBody(BODYFUNC(ib) {
   auto dout = ib->GetInput(kIndex3);
   bool reverse = GetValue<bool>(ib->GetAttr("reverse"));
   NodePtr dtype_min = nullptr;
-  auto fp64_flag = false;
   if ((ib->GetDtype(x))->type_id() == TypeId::kNumberTypeFloat16) {
     dtype_min = ib->Tensor(-65500e+0, kFloat16);
   } else {
     if ((ib->GetDtype(x))->type_id() == TypeId::kNumberTypeFloat32) {
       dtype_min = ib->Tensor(-3.4028235e+38, kFloat32);
     } else {
-      if ((ib->GetDtype(x))->type_id() == TypeId::kNumberTypeFloat64) {
-        dout = ib->Cast(dout, kFloat32);
-        x = ib->Cast(x, kFloat32);
-        out = ib->Cast(out, kFloat32);
-        dtype_min = ib->Tensor(-3.4028235e+38, kFloat32);
-        fp64_flag = true;
+      dtype_min = ib->Tensor(-1.7976931348623157e+308, kFloat64);
-      }
     }
   }
-  dtype_min = ib->Emit("BroadcastTo", {dtype_min}, {{"shape", MakeValue(ib->GetShape(dout))}});
   auto log_grad_positive = ib->Select(ib->Greater(dout, ib->Tensor(0, ib->GetDtype(dout))), ib->Log(dout), dtype_min);
   auto log_grad_negative =
     ib->Select(ib->Less(dout, ib->Tensor(0, ib->GetDtype(dout))), ib->Log(ib->Neg(dout)), dtype_min);
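Note on the constants above: each dtype_min sits at (or, for float16, just inside) the most negative finite value of its dtype and acts as a log-space stand-in for log(0), so anything Select masks out vanishes after the final exponentiation. The old code downcast float64 to float32 and broadcast dtype_min to the static shape of dout, both of which fail under dynamic shapes; the fix keeps float64 native and lets Select broadcast the scalar. A minimal plain-NumPy check (outside MindSpore) that exp(dtype_min) underflows to exactly zero:

import numpy as np

# exp(dtype_min) underflows to 0.0 in every dtype the bprop handles, so
# positions filled with dtype_min contribute nothing after the final exp.
for val, dt in ((-65500.0, np.float16),
                (-3.4028235e+38, np.float32),
                (-1.7976931348623157e+308, np.float64)):
    print(dt.__name__, np.exp(dt(val)))  # 0.0 for all three dtypes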
@@ -1721,11 +1713,6 @@ REG_BPROP_BUILDER("CumulativeLogsumexp").SetBody(BODYFUNC(ib) {
     ib->Exp(ib->Add((ib->Emit("CumulativeLogsumexp", {ib->Sub(log_grad_negative, out), axis},
                               {{"exclusive", ib->GetAttr("exclusive")}, {"reverse", MakeValue(!reverse)}})),
                     x));
-  if (fp64_flag) {
-    output_pos = ib->Cast(output_pos, kFloat64);
-    output_neg = ib->Cast(output_neg, kFloat64);
-    x = ib->Cast(x, kFloat64);
-  }
   return {ib->Sub(output_pos, output_neg), ib->ZerosLike(x)};
 });

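For reference, the formula these two hunks implement: with y = CumulativeLogsumexp(x), the vector-Jacobian product is dx = exp(clse(log(dout+) - y, !reverse) + x) - exp(clse(log(dout-) - y, !reverse) + x), where dout+ and dout- are the positive and negative parts of dout, clse is a cumulative log-sum-exp in the opposite direction, and dtype_min stands in for log(0). A self-contained NumPy sketch with a finite-difference check (assuming exclusive=False; names like bprop_ref are illustrative, not MindSpore API, and y is recomputed here instead of reusing the forward output):

import numpy as np

def cumlogsumexp(a, reverse=False):
    # running log-sum-exp along axis 0
    if reverse:
        return np.logaddexp.accumulate(a[::-1])[::-1]
    return np.logaddexp.accumulate(a)

def bprop_ref(x, dout, reverse=False):
    # mirrors the patched bprop: split dout by sign, work in log space,
    # and use the dtype's most negative finite value as log(0)
    y = cumlogsumexp(x, reverse)
    dtype_min = np.finfo(x.dtype).min
    log_pos = np.where(dout > 0, np.log(np.abs(dout)), dtype_min)
    log_neg = np.where(dout < 0, np.log(np.abs(dout)), dtype_min)
    out_pos = np.exp(cumlogsumexp(log_pos - y, not reverse) + x)
    out_neg = np.exp(cumlogsumexp(log_neg - y, not reverse) + x)
    return out_pos - out_neg

# central-difference check of the VJP against the closed form above
rng = np.random.default_rng(0)
x, dout = rng.normal(size=5), rng.normal(size=5)
eps = 1e-6
num = np.array([(cumlogsumexp(x + eps * e) - cumlogsumexp(x - eps * e)) @ dout / (2 * eps)
                for e in np.eye(5)])
print(np.allclose(num, bprop_ref(x, dout), atol=1e-5))  # True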
@@ -25,7 +25,7 @@ from mindspore.ops.functional import broadcast_gradient_args
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
-from mindspore import nn, ops, Tensor
+from mindspore import nn, Tensor
 from mindspore.ops.operations.math_ops import Real, Imag, Complex, Angle
 from mindspore.ops.operations.math_ops import Polar
 from mindspore.ops.operations.math_ops import ComplexAbs
@@ -468,19 +468,6 @@ def get_brop_cumulative_logsumexp(self):
     less_op = P.Less()
     neg_op = P.Neg()

-    def where_v2(condition, x=None, y=None):
-        return_all = None
-        if x is None and y is None:
-            return_all = mnp.where(condition, x, y)
-        elif x is not None and y is not None:
-            shape_ = x.shape
-            input_y = np.resize(y, shape_)
-            input_y = Tensor(input_y).astype(x.dtype)
-            return_all = ops.select(condition, x, input_y)
-        else:
-            raise ValueError("x and y must both be non-None or both be None.")
-        return return_all
-
     def bprop(x, axis, out, dout):
         dtype_min = 0
         if x.dtype == mstype.float16:
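The deleted where_v2 helper materialized the scalar fill value with np.resize(y, x.shape), which needs a concrete shape at trace time and so fails once x has a dynamic shape. mnp.where needs no such workaround because, like np.where, it broadcasts a scalar third argument against the condition. A plain-NumPy illustration of the semantics mindspore.numpy.where mirrors:

import numpy as np

# A scalar fill value broadcasts against the condition; no shape is needed.
cond = np.array([True, False, True])
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(np.where(cond, x, np.float32(-3.4028235e+38)))
# -> [ 1.0000000e+00 -3.4028235e+38  3.0000000e+00]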
@@ -489,8 +476,8 @@ def get_brop_cumulative_logsumexp(self):
             dtype_min = -3.4028235e+38
         elif x.dtype == mstype.float64:
             dtype_min = -1.7976931348623157e+308
-        log_grad_positive = where_v2(greater_op(dout, 0), log_op(dout), dtype_min)
-        log_grad_negative = where_v2(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
+        log_grad_positive = mnp.where(greater_op(dout, 0), log_op(dout), dtype_min)
+        log_grad_negative = mnp.where(less_op(dout, 0), log_op(neg_op(dout)), dtype_min)
         output_pos = exp_op(cumulative_op(log_grad_positive - out, axis) + x)
         output_neg = exp_op(cumulative_op(log_grad_negative - out, axis) + x)
         return (output_pos - output_neg, zeros_like(x))
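A hypothetical end-to-end sketch of what the fix enables: differentiating CumulativeLogsumexp through a cell whose leading dimension is left dynamic with set_inputs. The exact primitive signature (here axis passed as a 0-D int64 Tensor) is an assumption and may vary across MindSpore versions:

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

# Sketch only, not from the patch: gradient of CumulativeLogsumexp under a
# dynamic first dimension, which the old bprop's static BroadcastTo and
# float64 downcast could not handle.
class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.cumlse = ops.CumulativeLogsumexp()  # exclusive=False, reverse=False
        self.axis = Tensor(0, ms.int64)          # assumed 0-D tensor axis input

    def construct(self, x):
        return self.cumlse(x, self.axis)

net = Net()
net.set_inputs(Tensor(shape=[None], dtype=ms.float32))  # leading dim dynamic
grad_fn = ms.grad(net)                                   # gradient w.r.t. x
print(grad_fn(Tensor(np.array([0.1, 0.2, 0.3], np.float32))))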