forked from mindspore-Ecosystem/mindspore
!9618 Modify the names of parameter check
From: @peixu_ren Reviewed-by: @zichun_ye,@sunnybeike Signed-off-by: @sunnybeike
This commit is contained in:
commit
e7555043bd
|
@ -245,7 +245,7 @@ class LGamma(Cell):
|
|||
|
||||
def construct(self, x):
|
||||
input_dtype = self.dtype(x)
|
||||
_check_input_dtype("input", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
_check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
infinity = self.fill(input_dtype, self.shape(x), self.inf)
|
||||
|
||||
need_to_reflect = self.less(x, 0.5)
|
||||
|
@ -352,7 +352,7 @@ class DiGamma(Cell):
|
|||
|
||||
def construct(self, x):
|
||||
input_dtype = self.dtype(x)
|
||||
_check_input_dtype("input_x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
_check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
need_to_reflect = self.less(x, 0.5)
|
||||
neg_input = -x
|
||||
z = self.select(need_to_reflect, neg_input, x - 1)
|
||||
|
@ -612,8 +612,8 @@ class IGamma(Cell):
|
|||
def construct(self, a, x):
|
||||
a_dtype = self.dtype(a)
|
||||
x_dtype = self.dtype(x)
|
||||
_check_input_dtype("input_a", a_dtype, [mstype.float32], self.cls_name)
|
||||
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
|
||||
_check_input_dtype("a", a_dtype, [mstype.float32], self.cls_name)
|
||||
_check_input_dtype("x", x_dtype, a_dtype, self.cls_name)
|
||||
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
|
||||
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
|
||||
ax = a * self.log(x) - x - self.lgamma(a)
|
||||
|
@ -688,8 +688,8 @@ class LBeta(Cell):
|
|||
def construct(self, x, y):
|
||||
x_dtype = self.dtype(x)
|
||||
y_dtype = self.dtype(y)
|
||||
_check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
_check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
|
||||
_check_input_dtype("x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
_check_input_dtype("y", y_dtype, x_dtype, self.cls_name)
|
||||
x_plus_y = x + y
|
||||
para_shape = self.shape(x_plus_y)
|
||||
broadcastto = P.BroadcastTo(para_shape)
|
||||
|
|
Loading…
Reference in New Issue