From dc11fa9f53c61d691d1686864ad7616f388b8452 Mon Sep 17 00:00:00 2001 From: Xun Deng Date: Fri, 21 Aug 2020 13:50:01 -0400 Subject: [PATCH] Fixed CheckTuple issues and error message --- .../probability/distribution/_utils/utils.py | 19 +++++++++++++------ .../nn/probability/distribution/bernoulli.py | 6 +----- .../probability/distribution/distribution.py | 4 ++++ .../probability/distribution/exponential.py | 5 +---- .../nn/probability/distribution/geometric.py | 5 +---- .../nn/probability/distribution/normal.py | 6 +----- .../distribution/transformed_distribution.py | 2 +- .../nn/probability/distribution/uniform.py | 6 +----- 8 files changed, 23 insertions(+), 30 deletions(-) diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 74c265b3672..2da4ca30d1e 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype from mindspore.ops import operations as P from mindspore.ops import composite as C +from mindspore import context import mindspore.nn as nn import mindspore.nn.probability as msp @@ -273,7 +274,8 @@ def check_type(data_type, value_type, name): @constexpr def raise_none_error(name): - raise ValueError(f"{name} should be specified. Value cannot be None") + raise TypeError(f"the type {name} should be subclass of Tensor." + f" It should not be None since it is not specified during initialization.") @constexpr def raise_not_impl_error(name): @@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer): def __infer__(self, x, name): if not isinstance(x['dtype'], tuple): - raise TypeError("Input type should be a tuple: " + name["value"]) + raise TypeError(f"For {name['value']}, Input type should b a tuple.") out = {'shape': None, 'dtype': None, - 'value': None} + 'value': x["value"]} return out - def __call__(self, *args): - return + def __call__(self, x, name): + if context.get_context("mode") == 0: + return x["value"] + #Pynative mode + if isinstance(x, tuple): + return x + raise TypeError(f"For {name['value']}, Input type should b a tuple.") class CheckTensor(PrimitiveWithInfer): """ @@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer): 'value': None} return out - def __call__(self, *args): + def __call__(self, x, name): return diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 512a935ba8f..2ef9ed83215 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -18,7 +18,6 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error -from ._utils.utils import CheckTensor, CheckTuple from ._utils.custom_ops import log_by_step class Bernoulli(Distribution): @@ -125,9 +124,6 @@ class Bernoulli(Distribution): self.sqrt = P.Sqrt() self.uniform = C.uniform - self.checktensor = CheckTensor() - self.checktuple = CheckTuple() - def extend_repr(self): if self.is_scalar_batch: str_info = f'probs = {self.probs}' @@ -279,7 +275,7 @@ class Bernoulli(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - self.checktuple(shape, 'shape') + shape = self.checktuple(shape, 'shape') probs1 = self._check_param(probs1) origin_shape = shape + self.shape(probs1) if origin_shape == (): diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index fffc7ed69e9..7a1385daede 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param +from ._utils.utils import CheckTuple, CheckTensor class Distribution(Cell): """ @@ -79,6 +80,9 @@ class Distribution(Cell): self._set_log_survival() self._set_cross_entropy() + self.checktuple = CheckTuple() + self.checktensor = CheckTensor() + @property def name(self): return self._name diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 5a6ada38d5a..1311a43c585 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\ raise_none_error -from ._utils.utils import CheckTensor, CheckTuple from ._utils.custom_ops import log_by_step class Exponential(Distribution): @@ -127,8 +126,6 @@ class Exponential(Distribution): self.sq = P.Square() self.uniform = C.uniform - self.checktensor = CheckTensor() - self.checktuple = CheckTuple() def extend_repr(self): if self.is_scalar_batch: @@ -270,7 +267,7 @@ class Exponential(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - self.checktuple(shape, 'shape') + shape = self.checktuple(shape, 'shape') rate = self._check_param(rate) origin_shape = shape + self.shape(rate) if origin_shape == (): diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index 9b1d8669665..8065af53a57 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\ raise_none_error -from ._utils.utils import CheckTensor, CheckTuple from ._utils.custom_ops import log_by_step class Geometric(Distribution): @@ -131,8 +130,6 @@ class Geometric(Distribution): self.sqrt = P.Sqrt() self.uniform = C.uniform - self.checktensor = CheckTensor() - self.checktuple = CheckTuple() def extend_repr(self): if self.is_scalar_batch: @@ -278,7 +275,7 @@ class Geometric(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - self.checktuple(shape, 'shape') + shape = self.checktuple(shape, 'shape') probs1 = self._check_param(probs1) origin_shape = shape + self.shape(probs1) if origin_shape == (): diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index fc9a35908d5..9993c5ec093 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\ raise_none_error -from ._utils.utils import CheckTensor, CheckTuple from ._utils.custom_ops import log_by_step, expm1_by_step class Normal(Distribution): @@ -128,9 +127,6 @@ class Normal(Distribution): self.sqrt = P.Sqrt() self.zeroslike = P.ZerosLike() - self.checktensor = CheckTensor() - self.checktuple = CheckTuple() - def extend_repr(self): if self.is_scalar_batch: str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' @@ -277,7 +273,7 @@ class Normal(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - self.checktuple(shape, 'shape') + shape = self.checktuple(shape, 'shape') mean, sd = self._check_param(mean, sd) batch_shape = self.shape(mean + sd) origin_shape = shape + batch_shape diff --git a/mindspore/nn/probability/distribution/transformed_distribution.py b/mindspore/nn/probability/distribution/transformed_distribution.py index 850df02e14a..9ca9f6bdf13 100644 --- a/mindspore/nn/probability/distribution/transformed_distribution.py +++ b/mindspore/nn/probability/distribution/transformed_distribution.py @@ -116,4 +116,4 @@ class TransformedDistribution(Distribution): if not self.is_linear_transformation: raise_not_impl_error("mean") - return self.bijector("forward", self.distribution("mean")) + return self.bijector("forward", self.distribution("mean", *args, **kwargs)) diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 0d1b96c9e67..d5d3aa6f34f 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype from .distribution import Distribution from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\ raise_none_error -from ._utils.utils import CheckTensor, CheckTuple from ._utils.custom_ops import log_by_step class Uniform(Distribution): @@ -131,9 +130,6 @@ class Uniform(Distribution): self.zeroslike = P.ZerosLike() self.uniform = C.uniform - self.checktensor = CheckTensor() - self.checktuple = CheckTuple() - def extend_repr(self): if self.is_scalar_batch: str_info = f'low = {self.low}, high = {self.high}' @@ -306,7 +302,7 @@ class Uniform(Distribution): Returns: Tensor, shape is shape + batch_shape. """ - self.checktuple(shape, 'shape') + shape = self.checktuple(shape, 'shape') low, high = self._check_param(low, high) broadcast_shape = self.shape(low + high) origin_shape = shape + broadcast_shape