forked from mindspore-Ecosystem/mindspore
fix error message in checkTensor and bug in bernoulli cross entropy
This commit is contained in:
parent
0cb6d29f0c
commit
80d8f30a7c
|
@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer):
|
||||||
def __call__(self, x, name):
|
def __call__(self, x, name):
|
||||||
if isinstance(x, Tensor):
|
if isinstance(x, Tensor):
|
||||||
return x
|
return x
|
||||||
raise TypeError(f"For {name}, input type should be a Tensor.")
|
raise TypeError(f"For {name}, input type should be a Tensor or Parameter.")
|
||||||
|
|
||||||
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
|
def common_dtype(arg_a, name_a, arg_b, name_b, hint_type):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -186,8 +186,8 @@ class Bernoulli(Distribution):
|
||||||
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
|
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
|
||||||
"""
|
"""
|
||||||
probs1 = self._check_param(probs1)
|
probs1 = self._check_param(probs1)
|
||||||
probs0 = 1 - probs1
|
probs0 = 1.0 - probs1
|
||||||
return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1))
|
||||||
|
|
||||||
def _cross_entropy(self, dist, probs1_b, probs1=None):
|
def _cross_entropy(self, dist, probs1_b, probs1=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -27,7 +27,7 @@ class Distribution(Cell):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed (int): random seed used in sampling.
|
seed (int): random seed used in sampling.
|
||||||
dtype (mindspore.dtype): type of the distribution.
|
dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype.
|
||||||
name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
|
name (str): Python str name prefixed to Ops created by this class. Default: subclass name.
|
||||||
param (dict): parameters used to initialize the distribution.
|
param (dict): parameters used to initialize the distribution.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue