fix error message in checkTensor and bug in bernoulli cross entropy

This commit is contained in:
Xun Deng 2020-09-01 15:09:21 -04:00
parent 0cb6d29f0c
commit 80d8f30a7c
3 changed files with 4 additions and 4 deletions

View File

@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer):
def __call__(self, x, name):
if isinstance(x, Tensor):
return x
raise TypeError(f"For {name}, input type should be a Tensor.")
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
"""

View File

@ -186,8 +186,8 @@ class Bernoulli(Distribution):
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
probs1 = self._check_param(probs1)
probs0 = 1 - probs1
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
probs0 = 1.0 - probs1
return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
def _cross_entropy(self, dist, probs1_b, probs1=None):
"""

View File

@ -27,7 +27,7 @@ class Distribution(Cell):
Args:
seed (int): random seed used in sampling.
dtype (mindspore.dtype): type of the distribution.
dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype.
name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
param (dict): parameters used to initialize the distribution.