!9477 Fix a bug that a and x cannot be broadcasted

From: @peixu_ren
Reviewed-by: @liangchenghui,@zichun_ye
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2020-12-06 02:22:37 +08:00 committed by Gitee
commit 715872a09e
1 changed files with 6 additions and 6 deletions

View File

@ -670,10 +670,10 @@ class IGamma(Cell):
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
ax = a * self.log(x) - x - self.lgamma(a)
para_shape = self.shape(ax)
boradcastto = P.BroadcastTo(para_shape)
broadcastto = P.BroadcastTo(para_shape)
if para_shape != ():
x = boradcastto(x)
y = boradcastto(y)
x = broadcastto(x)
a = broadcastto(a)
x_is_zero = self.equal(x, 0)
log_maxfloat = self.log_maxfloat32
underflow = self.less(ax, self.neg(log_maxfloat))
@ -744,10 +744,10 @@ class LBeta(Cell):
_check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
x_plus_y = x + y
para_shape = self.shape(x_plus_y)
boradcastto = P.BroadcastTo(para_shape)
broadcastto = P.BroadcastTo(para_shape)
if para_shape != ():
x = boradcastto(x)
y = boradcastto(y)
x = broadcastto(x)
y = broadcastto(y)
comp_less = self.less(x, y)
x_min = self.select(comp_less, x, y)
y_max = self.select(comp_less, y, x)