forked from mindspore-Ecosystem/mindspore
fix a bug in categorical distribution
This commit is contained in:
parent
2c9e634d0d
commit
4ea0e6f257
|
@ -175,7 +175,6 @@ class Categorical(Distribution):
|
|||
self.squeeze_last_axis = P.Squeeze(-1)
|
||||
self.square = P.Square()
|
||||
self.transpose = P.Transpose()
|
||||
self.is_nan = P.IsNan()
|
||||
|
||||
self.index_type = mstype.int32
|
||||
self.nan = np.nan
|
||||
|
@ -291,10 +290,6 @@ class Categorical(Distribution):
|
|||
value = self.cast(value, self.dtypeop(probs))
|
||||
|
||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
|
||||
value = self.select(self.is_nan(value),
|
||||
neg_one,
|
||||
value)
|
||||
between_zero_neone = self.logicand(self.less(value, 0,),
|
||||
self.greater(value, -1.))
|
||||
value = self.select(between_zero_neone,
|
||||
|
@ -359,10 +354,6 @@ class Categorical(Distribution):
|
|||
value = self.cast(value, self.dtypeop(probs))
|
||||
|
||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
|
||||
value = self.select(self.is_nan(value),
|
||||
neg_one,
|
||||
value)
|
||||
between_zero_neone = self.logicand(self.less(value, 0,),
|
||||
self.greater(value, -1.))
|
||||
value = self.select(between_zero_neone,
|
||||
|
|
Loading…
Reference in New Issue