fix a bug in categorical distribution

This commit is contained in:
Xun Deng 2020-12-14 20:04:09 -05:00
parent 2c9e634d0d
commit 4ea0e6f257
1 changed files with 0 additions and 9 deletions

View File

@ -175,7 +175,6 @@ class Categorical(Distribution):
self.squeeze_last_axis = P.Squeeze(-1) self.squeeze_last_axis = P.Squeeze(-1)
self.square = P.Square() self.square = P.Square()
self.transpose = P.Transpose() self.transpose = P.Transpose()
self.is_nan = P.IsNan()
self.index_type = mstype.int32 self.index_type = mstype.int32
self.nan = np.nan self.nan = np.nan
@ -291,10 +290,6 @@ class Categorical(Distribution):
value = self.cast(value, self.dtypeop(probs)) value = self.cast(value, self.dtypeop(probs))
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
value = self.select(self.is_nan(value),
neg_one,
value)
between_zero_neone = self.logicand(self.less(value, 0,), between_zero_neone = self.logicand(self.less(value, 0,),
self.greater(value, -1.)) self.greater(value, -1.))
value = self.select(between_zero_neone, value = self.select(between_zero_neone,
@ -359,10 +354,6 @@ class Categorical(Distribution):
value = self.cast(value, self.dtypeop(probs)) value = self.cast(value, self.dtypeop(probs))
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0) zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
value = self.select(self.is_nan(value),
neg_one,
value)
between_zero_neone = self.logicand(self.less(value, 0,), between_zero_neone = self.logicand(self.less(value, 0,),
self.greater(value, -1.)) self.greater(value, -1.))
value = self.select(between_zero_neone, value = self.select(between_zero_neone,