forked from OSSInnovation/mindspore
!3345 Fix minor bugs in denoting and test cases
Merge pull request !3345 from peixu_ren/custom_gpu
This commit is contained in:
commit
2f1a5b979d
|
@ -27,7 +27,7 @@ from .clip_ops import clip_by_value
|
||||||
from .multitype_ops.add_impl import hyper_add
|
from .multitype_ops.add_impl import hyper_add
|
||||||
from .multitype_ops.ones_like_impl import ones_like
|
from .multitype_ops.ones_like_impl import ones_like
|
||||||
from .multitype_ops.zeros_like_impl import zeros_like
|
from .multitype_ops.zeros_like_impl import zeros_like
|
||||||
from .random_ops import normal
|
from .random_ops import set_seed, normal
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -48,5 +48,6 @@ __all__ = [
|
||||||
'zeros_like',
|
'zeros_like',
|
||||||
'ones_like',
|
'ones_like',
|
||||||
'zip_operation',
|
'zip_operation',
|
||||||
|
'set_seed',
|
||||||
'normal',
|
'normal',
|
||||||
'clip_by_value',]
|
'clip_by_value',]
|
||||||
|
|
|
@ -15,8 +15,11 @@
|
||||||
|
|
||||||
"""Operations for random number generatos."""
|
"""Operations for random number generatos."""
|
||||||
|
|
||||||
from mindspore.ops.primitive import constexpr
|
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
|
from .. import functional as F
|
||||||
|
from ..primitive import constexpr
|
||||||
|
from .multitype_ops import _constexpr_utils as const_utils
|
||||||
|
from ...common import dtype as mstype
|
||||||
|
|
||||||
# set graph-level RNG seed
|
# set graph-level RNG seed
|
||||||
_GRAPH_SEED = 0
|
_GRAPH_SEED = 0
|
||||||
|
@ -31,17 +34,17 @@ def get_seed():
|
||||||
return _GRAPH_SEED
|
return _GRAPH_SEED
|
||||||
|
|
||||||
|
|
||||||
def normal(shape, mean, stddev, seed):
|
def normal(shape, mean, stddev, seed=0):
|
||||||
"""
|
"""
|
||||||
Generates random numbers according to the Normal (or Gaussian) random number distribution.
|
Generates random numbers according to the Normal (or Gaussian) random number distribution.
|
||||||
It is defined as:
|
It is defined as:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- **shape** (tuple) - The shape of random tensor to be generated.
|
shape (tuple): The shape of random tensor to be generated.
|
||||||
- **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak.
|
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
|
||||||
With float32 data type.
|
With float32 data type.
|
||||||
- **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type.
|
stddev (Tensor): The deviation σ distribution parameter. With float32 data type.
|
||||||
- **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
|
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
Default: 0.
|
Default: 0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed):
|
||||||
>>> shape = (4, 16)
|
>>> shape = (4, 16)
|
||||||
>>> mean = Tensor(1.0, mstype.float32)
|
>>> mean = Tensor(1.0, mstype.float32)
|
||||||
>>> stddev = Tensor(1.0, mstype.float32)
|
>>> stddev = Tensor(1.0, mstype.float32)
|
||||||
|
>>> C.set_seed(10)
|
||||||
>>> output = C.normal(shape, mean, stddev, seed=5)
|
>>> output = C.normal(shape, mean, stddev, seed=5)
|
||||||
"""
|
"""
|
||||||
set_seed(10)
|
mean_dtype = F.dtype(mean)
|
||||||
|
stddev_dtype = F.dtype(stddev)
|
||||||
|
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
|
||||||
|
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
|
||||||
seed1 = get_seed()
|
seed1 = get_seed()
|
||||||
seed2 = seed
|
seed2 = seed
|
||||||
stdnormal = P.StandardNormal(seed1, seed2)
|
stdnormal = P.StandardNormal(seed1, seed2)
|
||||||
|
|
|
@ -29,7 +29,7 @@ class Net(nn.Cell):
|
||||||
self.stdnormal = P.StandardNormal(seed, seed2)
|
self.stdnormal = P.StandardNormal(seed, seed2)
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.stdnormal(self.shape, self.seed, self.seed2)
|
return self.stdnormal(self.shape)
|
||||||
|
|
||||||
|
|
||||||
def test_net():
|
def test_net():
|
||||||
|
|
|
@ -29,7 +29,7 @@ class Net(nn.Cell):
|
||||||
self.stdnormal = P.StandardNormal(seed, seed2)
|
self.stdnormal = P.StandardNormal(seed, seed2)
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
return self.stdnormal(self.shape, self.seed, self.seed2)
|
return self.stdnormal(self.shape)
|
||||||
|
|
||||||
|
|
||||||
def test_net():
|
def test_net():
|
||||||
|
|
Loading…
Reference in New Issue