forked from mindspore-Ecosystem/mindspore
fix minor issues
This commit is contained in:
parent
8ddb10fd8a
commit
855cb855af
|
@ -238,7 +238,8 @@ class Beta(Distribution):
|
|||
comp1 = self.greater(concentration1, 1.)
|
||||
comp2 = self.greater(concentration0, 1.)
|
||||
cond = self.logicaland(comp1, comp2)
|
||||
nan = self.fill(self.dtype, self.broadcast_shape, np.nan)
|
||||
batch_shape = self.shape(concentration1 + concentration0)
|
||||
nan = self.fill(self.dtype, batch_shape, np.nan)
|
||||
mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
|
||||
return self.select(cond, mode, nan)
|
||||
|
||||
|
|
|
@ -212,6 +212,7 @@ class Poisson(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, "value")
|
||||
value = self.cast(value, self.dtype)
|
||||
value = self.floor(value)
|
||||
rate = self._check_param_type(rate)
|
||||
log_rate = self.log(rate)
|
||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||
|
@ -239,6 +240,7 @@ class Poisson(Distribution):
|
|||
"""
|
||||
value = self._check_value(value, 'value')
|
||||
value = self.cast(value, self.dtype)
|
||||
value = self.floor(value)
|
||||
rate = self._check_param_type(rate)
|
||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||
comp = self.less(value, zeros)
|
||||
|
@ -259,6 +261,9 @@ class Poisson(Distribution):
|
|||
"""
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
rate = self._check_param_type(rate)
|
||||
|
||||
# now Poisson sampler supports only fp32
|
||||
rate = self.cast(rate, mstype.float32)
|
||||
origin_shape = shape + self.shape(rate)
|
||||
if origin_shape == ():
|
||||
sample_shape = (1,)
|
||||
|
|
Loading…
Reference in New Issue