forked from mindspore-Ecosystem/mindspore
fix error message in checkTensor and bug in bernoulli cross entropy
This commit is contained in:
parent
0cb6d29f0c
commit
80d8f30a7c
|
@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer):
|
|||
def __call__(self, x, name):
|
||||
if isinstance(x, Tensor):
|
||||
return x
|
||||
raise TypeError(f"For {name}, input type should be a Tensor.")
|
||||
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
|
||||
|
||||
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
|
||||
"""
|
||||
|
|
|
@ -186,8 +186,8 @@ class Bernoulli(Distribution):
|
|||
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
|
||||
"""
|
||||
probs1 = self._check_param(probs1)
|
||||
probs0 = 1 - probs1
|
||||
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
||||
probs0 = 1.0 - probs1
|
||||
return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
||||
|
||||
def _cross_entropy(self, dist, probs1_b, probs1=None):
|
||||
"""
|
||||
|
|
|
@ -27,7 +27,7 @@ class Distribution(Cell):
|
|||
|
||||
Args:
|
||||
seed (int): random seed used in sampling.
|
||||
dtype (mindspore.dtype): type of the distribution.
|
||||
dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype.
|
||||
name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
|
||||
param (dict): parameters used to initialize the distribution.
|
||||
|
||||
|
|
Loading…
Reference in New Issue