From 80d8f30a7c8d07ae7cfbbcb384c28eb685da5c91 Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Tue, 1 Sep 2020 15:09:21 -0400 Subject: [PATCH] fix error message in checkTensor and bug in bernoulli cross entropy --- mindspore/nn/probability/distribution/_utils/utils.py | 2 +- mindspore/nn/probability/distribution/bernoulli.py | 4 ++-- mindspore/nn/probability/distribution/distribution.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 729c42634d7..2908f33b6b5 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -370,7 +370,7 @@ class CheckTensor(PrimitiveWithInfer): def __call__(self, x, name): if isinstance(x, Tensor): return x - raise TypeError(f"For {name}, input type should be a Tensor.") + raise TypeError(f"For {name}, input type should be a Tensor or Parameter.") def common_dtype(arg_a, name_a, arg_b, name_b, hint_type): """ diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 0dcbc59689b..7fce4b7802f 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -186,8 +186,8 @@ class Bernoulli(Distribution): H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1) """ probs1 = self._check_param(probs1) - probs0 = 1 - probs1 - return -1 * (probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) + probs0 = 1.0 - probs1 + return -(probs0 * self.log(probs0)) - (probs1 * self.log(probs1)) def _cross_entropy(self, dist, probs1_b, probs1=None): """ diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index dcb904aeac5..943b022057d 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -27,7 +27,7 @@ class Distribution(Cell): Args: seed (int): random seed used in sampling. - dtype (mindspore.dtype): type of the distribution. + dtype (mindspore.dtype): the type of the event samples. Default: subclass dtype. name (str): Python str name prefixed to Ops created by this class. Default: subclass name. param (dict): parameters used to initialize the distribution.