forked from mindspore-Ecosystem/mindspore
!9477 Fix a bug that a and x cannot be broadcasted
From: @peixu_ren Reviewed-by: @liangchenghui,@zichun_ye Signed-off-by: @liangchenghui
This commit is contained in:
commit
715872a09e
|
@ -670,10 +670,10 @@ class IGamma(Cell):
|
|||
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
|
||||
ax = a * self.log(x) - x - self.lgamma(a)
|
||||
para_shape = self.shape(ax)
|
||||
boradcastto = P.BroadcastTo(para_shape)
|
||||
broadcastto = P.BroadcastTo(para_shape)
|
||||
if para_shape != ():
|
||||
x = boradcastto(x)
|
||||
y = boradcastto(y)
|
||||
x = broadcastto(x)
|
||||
a = broadcastto(a)
|
||||
x_is_zero = self.equal(x, 0)
|
||||
log_maxfloat = self.log_maxfloat32
|
||||
underflow = self.less(ax, self.neg(log_maxfloat))
|
||||
|
@ -744,10 +744,10 @@ class LBeta(Cell):
|
|||
_check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
|
||||
x_plus_y = x + y
|
||||
para_shape = self.shape(x_plus_y)
|
||||
boradcastto = P.BroadcastTo(para_shape)
|
||||
broadcastto = P.BroadcastTo(para_shape)
|
||||
if para_shape != ():
|
||||
x = boradcastto(x)
|
||||
y = boradcastto(y)
|
||||
x = broadcastto(x)
|
||||
y = broadcastto(y)
|
||||
comp_less = self.less(x, y)
|
||||
x_min = self.select(comp_less, x, y)
|
||||
y_max = self.select(comp_less, y, x)
|
||||
|
|
Loading…
Reference in New Issue