forked from mindspore-Ecosystem/mindspore
fix a bug in categorical distribution
This commit is contained in:
parent
2c9e634d0d
commit
4ea0e6f257
|
@ -175,7 +175,6 @@ class Categorical(Distribution):
|
||||||
self.squeeze_last_axis = P.Squeeze(-1)
|
self.squeeze_last_axis = P.Squeeze(-1)
|
||||||
self.square = P.Square()
|
self.square = P.Square()
|
||||||
self.transpose = P.Transpose()
|
self.transpose = P.Transpose()
|
||||||
self.is_nan = P.IsNan()
|
|
||||||
|
|
||||||
self.index_type = mstype.int32
|
self.index_type = mstype.int32
|
||||||
self.nan = np.nan
|
self.nan = np.nan
|
||||||
|
@ -291,10 +290,6 @@ class Categorical(Distribution):
|
||||||
value = self.cast(value, self.dtypeop(probs))
|
value = self.cast(value, self.dtypeop(probs))
|
||||||
|
|
||||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||||
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
|
|
||||||
value = self.select(self.is_nan(value),
|
|
||||||
neg_one,
|
|
||||||
value)
|
|
||||||
between_zero_neone = self.logicand(self.less(value, 0,),
|
between_zero_neone = self.logicand(self.less(value, 0,),
|
||||||
self.greater(value, -1.))
|
self.greater(value, -1.))
|
||||||
value = self.select(between_zero_neone,
|
value = self.select(between_zero_neone,
|
||||||
|
@ -359,10 +354,6 @@ class Categorical(Distribution):
|
||||||
value = self.cast(value, self.dtypeop(probs))
|
value = self.cast(value, self.dtypeop(probs))
|
||||||
|
|
||||||
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
zeros = self.fill(self.dtypeop(value), self.shape(value), 0.0)
|
||||||
neg_one = self.fill(self.dtypeop(value), self.shape(value), -1.0)
|
|
||||||
value = self.select(self.is_nan(value),
|
|
||||||
neg_one,
|
|
||||||
value)
|
|
||||||
between_zero_neone = self.logicand(self.less(value, 0,),
|
between_zero_neone = self.logicand(self.less(value, 0,),
|
||||||
self.greater(value, -1.))
|
self.greater(value, -1.))
|
||||||
value = self.select(between_zero_neone,
|
value = self.select(between_zero_neone,
|
||||||
|
|
Loading…
Reference in New Issue