fix minor issues

This commit is contained in:
Xun Deng 2020-12-09 16:17:56 -05:00
parent 8ddb10fd8a
commit 855cb855af
2 changed files with 7 additions and 1 deletions

View File

@ -238,7 +238,8 @@ class Beta(Distribution):
comp1 = self.greater(concentration1, 1.)
comp2 = self.greater(concentration0, 1.)
cond = self.logicaland(comp1, comp2)
nan = self.fill(self.dtype, self.broadcast_shape, np.nan)
batch_shape = self.shape(concentration1 + concentration0)
nan = self.fill(self.dtype, batch_shape, np.nan)
mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
return self.select(cond, mode, nan)

View File

@ -212,6 +212,7 @@ class Poisson(Distribution):
"""
value = self._check_value(value, "value")
value = self.cast(value, self.dtype)
value = self.floor(value)
rate = self._check_param_type(rate)
log_rate = self.log(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
@ -239,6 +240,7 @@ class Poisson(Distribution):
"""
value = self._check_value(value, 'value')
value = self.cast(value, self.dtype)
value = self.floor(value)
rate = self._check_param_type(rate)
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
comp = self.less(value, zeros)
@ -259,6 +261,9 @@ class Poisson(Distribution):
"""
shape = self.checktuple(shape, 'shape')
rate = self._check_param_type(rate)
# now Poisson sampler supports only fp32
rate = self.cast(rate, mstype.float32)
origin_shape = shape + self.shape(rate)
if origin_shape == ():
sample_shape = (1,)