forked from mindspore-Ecosystem/mindspore
!4956 Fix CheckTuple in pynative mode
Merge pull request !4956 from XunDeng/pp_issue_branch
This commit is contained in:
commit
9f19076788
|
@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore import context
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.nn.probability as msp
|
import mindspore.nn.probability as msp
|
||||||
|
|
||||||
|
@ -273,7 +274,8 @@ def check_type(data_type, value_type, name):
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def raise_none_error(name):
|
def raise_none_error(name):
|
||||||
raise ValueError(f"{name} should be specified. Value cannot be None")
|
raise TypeError(f"the type {name} should be subclass of Tensor."
|
||||||
|
f" It should not be None since it is not specified during initialization.")
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def raise_not_impl_error(name):
|
def raise_not_impl_error(name):
|
||||||
|
@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer):
|
||||||
|
|
||||||
def __infer__(self, x, name):
|
def __infer__(self, x, name):
|
||||||
if not isinstance(x['dtype'], tuple):
|
if not isinstance(x['dtype'], tuple):
|
||||||
raise TypeError("Input type should be a tuple: " + name["value"])
|
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
|
||||||
|
|
||||||
out = {'shape': None,
|
out = {'shape': None,
|
||||||
'dtype': None,
|
'dtype': None,
|
||||||
'value': None}
|
'value': x["value"]}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, x, name):
|
||||||
return
|
if context.get_context("mode") == 0:
|
||||||
|
return x["value"]
|
||||||
|
#Pynative mode
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
return x
|
||||||
|
raise TypeError(f"For {name['value']}, Input type should b a tuple.")
|
||||||
|
|
||||||
class CheckTensor(PrimitiveWithInfer):
|
class CheckTensor(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, x, name):
|
||||||
return
|
return
|
||||||
|
|
|
@ -18,7 +18,6 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
|
||||||
from ._utils.custom_ops import log_by_step
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Bernoulli(Distribution):
|
class Bernoulli(Distribution):
|
||||||
|
@ -125,9 +124,6 @@ class Bernoulli(Distribution):
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
|
||||||
self.checktuple = CheckTuple()
|
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'probs = {self.probs}'
|
str_info = f'probs = {self.probs}'
|
||||||
|
@ -279,7 +275,7 @@ class Bernoulli(Distribution):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, shape is shape + batch_shape.
|
Tensor, shape is shape + batch_shape.
|
||||||
"""
|
"""
|
||||||
self.checktuple(shape, 'shape')
|
shape = self.checktuple(shape, 'shape')
|
||||||
probs1 = self._check_param(probs1)
|
probs1 = self._check_param(probs1)
|
||||||
origin_shape = shape + self.shape(probs1)
|
origin_shape = shape + self.shape(probs1)
|
||||||
if origin_shape == ():
|
if origin_shape == ():
|
||||||
|
|
|
@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore._checkparam import Rel
|
from mindspore._checkparam import Rel
|
||||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
||||||
|
from ._utils.utils import CheckTuple, CheckTensor
|
||||||
|
|
||||||
class Distribution(Cell):
|
class Distribution(Cell):
|
||||||
"""
|
"""
|
||||||
|
@ -79,6 +80,9 @@ class Distribution(Cell):
|
||||||
self._set_log_survival()
|
self._set_log_survival()
|
||||||
self._set_cross_entropy()
|
self._set_cross_entropy()
|
||||||
|
|
||||||
|
self.checktuple = CheckTuple()
|
||||||
|
self.checktensor = CheckTensor()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
|
|
|
@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
|
||||||
from ._utils.custom_ops import log_by_step
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Exponential(Distribution):
|
class Exponential(Distribution):
|
||||||
|
@ -127,8 +126,6 @@ class Exponential(Distribution):
|
||||||
self.sq = P.Square()
|
self.sq = P.Square()
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
|
||||||
self.checktuple = CheckTuple()
|
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
|
@ -270,7 +267,7 @@ class Exponential(Distribution):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, shape is shape + batch_shape.
|
Tensor, shape is shape + batch_shape.
|
||||||
"""
|
"""
|
||||||
self.checktuple(shape, 'shape')
|
shape = self.checktuple(shape, 'shape')
|
||||||
rate = self._check_param(rate)
|
rate = self._check_param(rate)
|
||||||
origin_shape = shape + self.shape(rate)
|
origin_shape = shape + self.shape(rate)
|
||||||
if origin_shape == ():
|
if origin_shape == ():
|
||||||
|
|
|
@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
|
||||||
from ._utils.custom_ops import log_by_step
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Geometric(Distribution):
|
class Geometric(Distribution):
|
||||||
|
@ -131,8 +130,6 @@ class Geometric(Distribution):
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
|
||||||
self.checktuple = CheckTuple()
|
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
|
@ -278,7 +275,7 @@ class Geometric(Distribution):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, shape is shape + batch_shape.
|
Tensor, shape is shape + batch_shape.
|
||||||
"""
|
"""
|
||||||
self.checktuple(shape, 'shape')
|
shape = self.checktuple(shape, 'shape')
|
||||||
probs1 = self._check_param(probs1)
|
probs1 = self._check_param(probs1)
|
||||||
origin_shape = shape + self.shape(probs1)
|
origin_shape = shape + self.shape(probs1)
|
||||||
if origin_shape == ():
|
if origin_shape == ():
|
||||||
|
|
|
@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
from ._utils.utils import convert_to_batch, check_greater_zero, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
|
||||||
from ._utils.custom_ops import log_by_step, expm1_by_step
|
from ._utils.custom_ops import log_by_step, expm1_by_step
|
||||||
|
|
||||||
class Normal(Distribution):
|
class Normal(Distribution):
|
||||||
|
@ -128,9 +127,6 @@ class Normal(Distribution):
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
self.zeroslike = P.ZerosLike()
|
self.zeroslike = P.ZerosLike()
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
|
||||||
self.checktuple = CheckTuple()
|
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
|
str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}'
|
||||||
|
@ -277,7 +273,7 @@ class Normal(Distribution):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, shape is shape + batch_shape.
|
Tensor, shape is shape + batch_shape.
|
||||||
"""
|
"""
|
||||||
self.checktuple(shape, 'shape')
|
shape = self.checktuple(shape, 'shape')
|
||||||
mean, sd = self._check_param(mean, sd)
|
mean, sd = self._check_param(mean, sd)
|
||||||
batch_shape = self.shape(mean + sd)
|
batch_shape = self.shape(mean + sd)
|
||||||
origin_shape = shape + batch_shape
|
origin_shape = shape + batch_shape
|
||||||
|
|
|
@ -116,4 +116,4 @@ class TransformedDistribution(Distribution):
|
||||||
if not self.is_linear_transformation:
|
if not self.is_linear_transformation:
|
||||||
raise_not_impl_error("mean")
|
raise_not_impl_error("mean")
|
||||||
|
|
||||||
return self.bijector("forward", self.distribution("mean"))
|
return self.bijector("forward", self.distribution("mean", *args, **kwargs))
|
||||||
|
|
|
@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
from ._utils.utils import convert_to_batch, check_greater, check_type, check_distribution_name,\
|
||||||
raise_none_error
|
raise_none_error
|
||||||
from ._utils.utils import CheckTensor, CheckTuple
|
|
||||||
from ._utils.custom_ops import log_by_step
|
from ._utils.custom_ops import log_by_step
|
||||||
|
|
||||||
class Uniform(Distribution):
|
class Uniform(Distribution):
|
||||||
|
@ -131,9 +130,6 @@ class Uniform(Distribution):
|
||||||
self.zeroslike = P.ZerosLike()
|
self.zeroslike = P.ZerosLike()
|
||||||
self.uniform = C.uniform
|
self.uniform = C.uniform
|
||||||
|
|
||||||
self.checktensor = CheckTensor()
|
|
||||||
self.checktuple = CheckTuple()
|
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
if self.is_scalar_batch:
|
if self.is_scalar_batch:
|
||||||
str_info = f'low = {self.low}, high = {self.high}'
|
str_info = f'low = {self.low}, high = {self.high}'
|
||||||
|
@ -306,7 +302,7 @@ class Uniform(Distribution):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, shape is shape + batch_shape.
|
Tensor, shape is shape + batch_shape.
|
||||||
"""
|
"""
|
||||||
self.checktuple(shape, 'shape')
|
shape = self.checktuple(shape, 'shape')
|
||||||
low, high = self._check_param(low, high)
|
low, high = self._check_param(low, high)
|
||||||
broadcast_shape = self.shape(low + high)
|
broadcast_shape = self.shape(low + high)
|
||||||
origin_shape = shape + broadcast_shape
|
origin_shape = shape + broadcast_shape
|
||||||
|
|
Loading…
Reference in New Issue