forked from mindspore-Ecosystem/mindspore
added some parameter checking
This commit is contained in:
parent
9ad82f79fd
commit
415dad3adb
|
@ -17,11 +17,14 @@ Distribution operation utility functions.
|
|||
"""
|
||||
from .utils import *
|
||||
|
||||
__all__ = ['convert_to_batch',
|
||||
'cast_to_tensor',
|
||||
'check_greater',
|
||||
'check_greater_equal_zero',
|
||||
'check_greater_zero',
|
||||
'calc_broadcast_shape_from_param',
|
||||
'check_scalar_from_param',
|
||||
'check_prob']
|
||||
__all__ = [
|
||||
'convert_to_batch',
|
||||
'cast_to_tensor',
|
||||
'check_greater',
|
||||
'check_greater_equal_zero',
|
||||
'check_greater_zero',
|
||||
'calc_broadcast_shape_from_param',
|
||||
'check_scalar_from_param',
|
||||
'check_prob',
|
||||
'check_type',
|
||||
]
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -23,7 +22,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
import mindspore.nn as nn
|
||||
|
||||
def cast_to_tensor(t, dtype=mstype.float32):
|
||||
def cast_to_tensor(t, hint_dtype=mstype.float32):
|
||||
"""
|
||||
Cast an user input value into a Tensor of dtype.
|
||||
If the input t is of type Parameter, t is directly returned as a Parameter.
|
||||
|
@ -41,25 +40,26 @@ def cast_to_tensor(t, dtype=mstype.float32):
|
|||
if isinstance(t, Parameter):
|
||||
return t
|
||||
if isinstance(t, Tensor):
|
||||
if t.dtype != hint_dtype:
|
||||
raise TypeError(f"Input tensor should be type {hint_dtype}.")
|
||||
#check if the Tensor in shape of Tensor(4)
|
||||
if t.dim() == 0:
|
||||
value = t.asnumpy()
|
||||
return Tensor([t], dtype=dtype)
|
||||
return Tensor([value], dtype=hint_dtype)
|
||||
#convert the type of tensor to dtype
|
||||
t.set_dtype(dtype)
|
||||
return t
|
||||
if isinstance(t, (list, np.ndarray)):
|
||||
return Tensor(t, dtype=dtype)
|
||||
return Tensor(t, dtype=hint_dtype)
|
||||
if np.isscalar(t):
|
||||
return Tensor([t], dtype=dtype)
|
||||
return Tensor([t], dtype=hint_dtype)
|
||||
raise RuntimeError("Input type is not supported.")
|
||||
|
||||
def convert_to_batch(t, batch_shape, dtype):
|
||||
def convert_to_batch(t, batch_shape, hint_dtype):
|
||||
"""
|
||||
Convert a Tensor to a given batch shape.
|
||||
|
||||
Args:
|
||||
t (Tensor, Parameter): Tensor to be converted.
|
||||
t (int, float, list, numpy.ndarray, Tensor, Parameter): Tensor to be converted.
|
||||
batch_shape (tuple): desired batch shape.
|
||||
dtype (mindspore.dtype): desired dtype.
|
||||
|
||||
|
@ -71,9 +71,8 @@ def convert_to_batch(t, batch_shape, dtype):
|
|||
"""
|
||||
if isinstance(t, Parameter):
|
||||
return t
|
||||
if isinstance(t, Tensor):
|
||||
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=dtype)
|
||||
return Tensor(np.broadcast_to(t, batch_shape), dtype=dtype)
|
||||
t = cast_to_tensor(t, hint_dtype)
|
||||
return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=hint_dtype)
|
||||
|
||||
def check_scalar_from_param(params):
|
||||
"""
|
||||
|
@ -85,6 +84,8 @@ def check_scalar_from_param(params):
|
|||
Notes: String parameters are excluded.
|
||||
"""
|
||||
for value in params.values():
|
||||
if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
|
||||
return params['distribution'].is_scalar_batch
|
||||
if isinstance(value, Parameter):
|
||||
return False
|
||||
if isinstance(value, (str, type(params['dtype']))):
|
||||
|
@ -108,6 +109,8 @@ def calc_broadcast_shape_from_param(params):
|
|||
"""
|
||||
broadcast_shape = []
|
||||
for value in params.values():
|
||||
if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
|
||||
return params['distribution'].broadcast_shape
|
||||
if isinstance(value, (str, type(params['dtype']))):
|
||||
continue
|
||||
if value is None:
|
||||
|
@ -251,3 +254,7 @@ def check_tensor_type(name, inputs, valid_type):
|
|||
inputs = P.DType()(inputs)
|
||||
if inputs not in valid_type:
|
||||
raise TypeError(f"{name} dtype is invalid")
|
||||
|
||||
def check_type(data_type, value_type, name):
|
||||
if not data_type in value_type:
|
||||
raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type
|
||||
|
||||
class Bernoulli(Distribution):
|
||||
"""
|
||||
|
@ -95,13 +95,14 @@ class Bernoulli(Distribution):
|
|||
Constructor of Bernoulli distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(Bernoulli, self).__init__(dtype, name, param)
|
||||
valid_dtype = mstype.int_type + mstype.uint_type
|
||||
check_type(dtype, valid_dtype, "Bernoulli")
|
||||
super(Bernoulli, self).__init__(seed, dtype, name, param)
|
||||
if probs is not None:
|
||||
self._probs = cast_to_tensor(probs, dtype=mstype.float32)
|
||||
self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
|
||||
check_prob(self.probs)
|
||||
else:
|
||||
self._probs = probs
|
||||
self.seed = seed
|
||||
|
||||
# ops needed for the class
|
||||
self.cast = P.Cast()
|
||||
|
@ -231,8 +232,8 @@ class Bernoulli(Distribution):
|
|||
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
|
||||
|
||||
.. math::
|
||||
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
|
||||
probs0_a * \log(\fract{probs0_a}{probs0_b})
|
||||
KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
|
||||
probs0_a * \log(\frac{probs0_a}{probs0_b})
|
||||
"""
|
||||
if dist == 'Bernoulli':
|
||||
probs1_a = self.probs if probs1_a is None else probs1_a
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""basic"""
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from ._utils.utils import calc_broadcast_shape_from_param, check_scalar_from_param
|
||||
|
||||
class Distribution(Cell):
|
||||
|
@ -38,6 +39,7 @@ class Distribution(Cell):
|
|||
original distribuion.
|
||||
"""
|
||||
def __init__(self,
|
||||
seed,
|
||||
dtype,
|
||||
name,
|
||||
param):
|
||||
|
@ -46,7 +48,11 @@ class Distribution(Cell):
|
|||
Constructor of distribution class.
|
||||
"""
|
||||
super(Distribution, self).__init__()
|
||||
validator.check_value_type('name', name, [str], 'distribution_name')
|
||||
validator.check_value_type('seed', seed, [int], name)
|
||||
|
||||
self._name = name
|
||||
self._seed = seed
|
||||
self._dtype = dtype
|
||||
self._parameters = {}
|
||||
# parsing parameters
|
||||
|
@ -77,6 +83,10 @@ class Distribution(Cell):
|
|||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def seed(self):
|
||||
return self._seed
|
||||
|
||||
@property
|
||||
def parameters(self):
|
||||
return self._parameters
|
||||
|
@ -85,6 +95,10 @@ class Distribution(Cell):
|
|||
def is_scalar_batch(self):
|
||||
return self._is_scalar_batch
|
||||
|
||||
@property
|
||||
def broadcast_shape(self):
|
||||
return self._broadcast_shape
|
||||
|
||||
def _set_prob(self):
|
||||
"""
|
||||
Set probability funtion based on the availability of _prob and _log_likehood.
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero
|
||||
from ._utils.utils import cast_to_tensor, check_greater_zero, check_type
|
||||
|
||||
class Exponential(Distribution):
|
||||
"""
|
||||
|
@ -96,9 +96,11 @@ class Exponential(Distribution):
|
|||
Constructor of Exponential distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(Exponential, self).__init__(dtype, name, param)
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, "Exponential")
|
||||
super(Exponential, self).__init__(seed, dtype, name, param)
|
||||
if rate is not None:
|
||||
self._rate = cast_to_tensor(rate, mstype.float32)
|
||||
self._rate = cast_to_tensor(rate, dtype)
|
||||
check_greater_zero(self._rate, "rate")
|
||||
else:
|
||||
self._rate = rate
|
||||
|
@ -135,7 +137,7 @@ class Exponential(Distribution):
|
|||
def _mean(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(EXP) = \fract{1.0}{\lambda}.
|
||||
MEAN(EXP) = \frac{1.0}{\lambda}.
|
||||
"""
|
||||
rate = self.rate if rate is None else rate
|
||||
return 1.0 / rate
|
||||
|
@ -152,7 +154,7 @@ class Exponential(Distribution):
|
|||
def _sd(self, rate=None):
|
||||
r"""
|
||||
.. math::
|
||||
sd(EXP) = \fract{1.0}{\lambda}.
|
||||
sd(EXP) = \frac{1.0}{\lambda}.
|
||||
"""
|
||||
rate = self.rate if rate is None else rate
|
||||
return 1.0 / rate
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import cast_to_tensor, check_prob
|
||||
from ._utils.utils import cast_to_tensor, check_prob, check_type
|
||||
|
||||
class Geometric(Distribution):
|
||||
"""
|
||||
|
@ -97,9 +97,11 @@ class Geometric(Distribution):
|
|||
Constructor of Geometric distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(Geometric, self).__init__(dtype, name, param)
|
||||
valid_dtype = mstype.int_type + mstype.uint_type
|
||||
check_type(dtype, valid_dtype, "Geometric")
|
||||
super(Geometric, self).__init__(seed, dtype, name, param)
|
||||
if probs is not None:
|
||||
self._probs = cast_to_tensor(probs, dtype=mstype.float32)
|
||||
self._probs = cast_to_tensor(probs, hint_dtype=mstype.float32)
|
||||
check_prob(self._probs)
|
||||
else:
|
||||
self._probs = probs
|
||||
|
@ -154,7 +156,7 @@ class Geometric(Distribution):
|
|||
def _var(self, probs1=None):
|
||||
r"""
|
||||
.. math::
|
||||
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
|
||||
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
|
||||
"""
|
||||
probs1 = self.probs if probs1 is None else probs1
|
||||
return (1.0 - probs1) / self.sq(probs1)
|
||||
|
@ -162,7 +164,7 @@ class Geometric(Distribution):
|
|||
def _entropy(self, probs=None):
|
||||
r"""
|
||||
.. math::
|
||||
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
|
||||
H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
|
||||
"""
|
||||
probs1 = self.probs if probs is None else probs
|
||||
probs0 = 1.0 - probs1
|
||||
|
@ -244,7 +246,7 @@ class Geometric(Distribution):
|
|||
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
|
||||
|
||||
.. math::
|
||||
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
|
||||
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
|
||||
"""
|
||||
if dist == 'Geometric':
|
||||
probs1_a = self.probs if probs1_a is None else probs1_a
|
||||
|
|
|
@ -18,7 +18,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater_equal_zero
|
||||
from ._utils.utils import convert_to_batch, check_greater_equal_zero, check_type
|
||||
|
||||
|
||||
class Normal(Distribution):
|
||||
|
@ -100,15 +100,17 @@ class Normal(Distribution):
|
|||
Constructor of normal distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(Normal, self).__init__(dtype, name, param)
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, "Normal")
|
||||
super(Normal, self).__init__(seed, dtype, name, param)
|
||||
if mean is not None and sd is not None:
|
||||
self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype)
|
||||
self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype)
|
||||
self._mean_value = convert_to_batch(mean, self.broadcast_shape, dtype)
|
||||
self._sd_value = convert_to_batch(sd, self.broadcast_shape, dtype)
|
||||
check_greater_equal_zero(self._sd_value, "Standard deviation")
|
||||
else:
|
||||
self._mean_value = mean
|
||||
self._sd_value = sd
|
||||
self.seed = seed
|
||||
|
||||
|
||||
#ops needed for the class
|
||||
self.const = P.ScalarToArray()
|
||||
|
@ -191,7 +193,7 @@ class Normal(Distribution):
|
|||
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
|
||||
|
||||
.. math::
|
||||
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
|
||||
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
|
||||
"""
|
||||
mean = self._mean_value if mean is None else mean
|
||||
sd = self._sd_value if sd is None else sd
|
||||
|
@ -229,7 +231,7 @@ class Normal(Distribution):
|
|||
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
|
||||
|
||||
.. math::
|
||||
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
|
||||
KL(a||b) = 0.5 * (\frac{MEAN(a)}{STD(b)} - \frac{MEAN(b)}{STD(b)}) ^ 2 +
|
||||
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
|
||||
"""
|
||||
if dist == 'Normal':
|
||||
|
|
|
@ -14,7 +14,11 @@
|
|||
# ============================================================================
|
||||
"""Transformed Distribution"""
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_type
|
||||
|
||||
class TransformedDistribution(Distribution):
|
||||
"""
|
||||
|
@ -35,12 +39,19 @@ class TransformedDistribution(Distribution):
|
|||
def __init__(self,
|
||||
bijector,
|
||||
distribution,
|
||||
dtype,
|
||||
seed=0,
|
||||
name="transformed_distribution"):
|
||||
"""
|
||||
Constructor of transformed_distribution class.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(TransformedDistribution, self).__init__(distribution.dtype, name, param)
|
||||
validator.check_value_type('bijector', bijector, [nn.probability.bijector.Bijector], name)
|
||||
validator.check_value_type('distribution', distribution, [Distribution], name)
|
||||
valid_dtype = mstype.number_type
|
||||
check_type(dtype, valid_dtype, "transformed_distribution")
|
||||
super(TransformedDistribution, self).__init__(seed, dtype, name, param)
|
||||
|
||||
self._bijector = bijector
|
||||
self._distribution = distribution
|
||||
self._is_linear_transformation = bijector.is_constant_jacobian
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import convert_to_batch, check_greater
|
||||
from ._utils.utils import convert_to_batch, check_greater, check_type
|
||||
|
||||
class Uniform(Distribution):
|
||||
"""
|
||||
|
@ -97,10 +97,12 @@ class Uniform(Distribution):
|
|||
Constructor of Uniform distribution.
|
||||
"""
|
||||
param = dict(locals())
|
||||
super(Uniform, self).__init__(dtype, name, param)
|
||||
valid_dtype = mstype.float_type
|
||||
check_type(dtype, valid_dtype, "Uniform")
|
||||
super(Uniform, self).__init__(seed, dtype, name, param)
|
||||
if low is not None and high is not None:
|
||||
self._low = convert_to_batch(low, self._broadcast_shape, dtype)
|
||||
self._high = convert_to_batch(high, self._broadcast_shape, dtype)
|
||||
self._low = convert_to_batch(low, self.broadcast_shape, dtype)
|
||||
self._high = convert_to_batch(high, self.broadcast_shape, dtype)
|
||||
check_greater(self.low, self.high, "low value", "high value")
|
||||
else:
|
||||
self._low = low
|
||||
|
@ -156,7 +158,7 @@ class Uniform(Distribution):
|
|||
def _mean(self, low=None, high=None):
|
||||
r"""
|
||||
.. math::
|
||||
MEAN(U) = \fract{low + high}{2}.
|
||||
MEAN(U) = \frac{low + high}{2}.
|
||||
"""
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
|
@ -166,7 +168,7 @@ class Uniform(Distribution):
|
|||
def _var(self, low=None, high=None):
|
||||
r"""
|
||||
.. math::
|
||||
VAR(U) = \fract{(high -low) ^ 2}{12}.
|
||||
VAR(U) = \frac{(high -low) ^ 2}{12}.
|
||||
"""
|
||||
low = self.low if low is None else low
|
||||
high = self.high if high is None else high
|
||||
|
@ -207,7 +209,7 @@ class Uniform(Distribution):
|
|||
|
||||
.. math::
|
||||
pdf(x) = 0 if x < low;
|
||||
pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
|
||||
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
|
||||
pdf(x) = 0 if x > high;
|
||||
"""
|
||||
low = self.low if low is None else low
|
||||
|
@ -251,7 +253,7 @@ class Uniform(Distribution):
|
|||
|
||||
.. math::
|
||||
cdf(x) = 0 if x < low;
|
||||
cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
|
||||
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
|
||||
cdf(x) = 1 if x > high;
|
||||
"""
|
||||
low = self.low if low is None else low
|
||||
|
|
|
@ -31,6 +31,18 @@ def test_arguments():
|
|||
b = msd.Bernoulli([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
|
||||
assert isinstance(b, msd.Distribution)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Bernoulli([0.1], dtype=dtype.float32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Bernoulli([0.1], name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Bernoulli([0.1], seed='seed')
|
||||
|
||||
def test_prob():
|
||||
"""
|
||||
Invalid probability.
|
||||
|
|
|
@ -32,6 +32,18 @@ def test_arguments():
|
|||
e = msd.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
|
||||
assert isinstance(e, msd.Distribution)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Exponential([0.1], dtype=dtype.int32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Exponential([0.1], name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Exponential([0.1], seed='seed')
|
||||
|
||||
def test_rate():
|
||||
"""
|
||||
Invalid rate.
|
||||
|
|
|
@ -32,6 +32,18 @@ def test_arguments():
|
|||
g = msd.Geometric([0.0, 0.3, 0.5, 1.0], dtype=dtype.int32)
|
||||
assert isinstance(g, msd.Distribution)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Geometric([0.1], dtype=dtype.float32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Geometric([0.1], name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Geometric([0.1], seed='seed')
|
||||
|
||||
def test_prob():
|
||||
"""
|
||||
Invalid probability.
|
||||
|
|
|
@ -30,6 +30,17 @@ def test_normal_shape_errpr():
|
|||
with pytest.raises(ValueError):
|
||||
msd.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Normal(0., 1., dtype=dtype.int32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Normal(0., 1., name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Normal(0., 1., seed='seed')
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
|
|
|
@ -30,6 +30,17 @@ def test_uniform_shape_errpr():
|
|||
with pytest.raises(ValueError):
|
||||
msd.Uniform([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)
|
||||
|
||||
def test_type():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Uniform(0., 1., dtype=dtype.int32)
|
||||
|
||||
def test_name():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Uniform(0., 1., name=1.0)
|
||||
|
||||
def test_seed():
|
||||
with pytest.raises(TypeError):
|
||||
msd.Uniform(0., 1., seed='seed')
|
||||
|
||||
def test_arguments():
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue