forked from mindspore-Ecosystem/mindspore
Make const numbers to tensors to avoid a bug
This commit is contained in:
parent
48f83e9039
commit
0a8a5a9a91
|
@ -684,6 +684,7 @@ class LBeta(Cell):
|
|||
self.shape = P.Shape()
|
||||
self.dtype = P.DType()
|
||||
self.lgamma = LGamma()
|
||||
self.const = P.ScalarToTensor()
|
||||
|
||||
def construct(self, x, y):
|
||||
x_dtype = self.dtype(x)
|
||||
|
@ -714,9 +715,9 @@ class LBeta(Cell):
|
|||
log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)
|
||||
|
||||
# Two large arguments case: y >= x >= 8.
|
||||
log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
|
||||
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
|
||||
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
|
||||
log_beta_two_large = self.const(0.5 * self.log_2pi, x_dtype) - 0.5 * self.log(y_max) \
|
||||
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
|
||||
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
|
||||
|
||||
cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
|
||||
correction = log_gamma_correction_y - log_gamma_correction_x_y
|
||||
|
|
Loading…
Reference in New Issue