forked from mindspore-Ecosystem/mindspore
add python implement and st for random gamma
This commit is contained in:
parent
45241fa3a6
commit
d0afaf60f0
|
@ -13,7 +13,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Operations for random number generators."""
|
"""Operations for random number generators."""
|
||||||
|
import numpy as np
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
|
from ...common.tensor import Tensor
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .multitype_ops import _constexpr_utils as const_utils
|
from .multitype_ops import _constexpr_utils as const_utils
|
||||||
|
@ -234,7 +236,7 @@ def gamma(shape, alpha, beta, seed=None):
|
||||||
TypeError: If dtype of `alpha` and `beta` is not float32.
|
TypeError: If dtype of `alpha` and `beta` is not float32.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend``
|
``Ascend`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mindspore import Tensor, ops
|
>>> from mindspore import Tensor, ops
|
||||||
|
@ -246,45 +248,41 @@ def gamma(shape, alpha, beta, seed=None):
|
||||||
>>> output = ops.gamma(shape, alpha, beta, seed=5)
|
>>> output = ops.gamma(shape, alpha, beta, seed=5)
|
||||||
>>> result = output.shape
|
>>> result = output.shape
|
||||||
>>> print(result)
|
>>> print(result)
|
||||||
(3, 2, 2)
|
(3, 1, 2, 2, 2)
|
||||||
>>> # case 2: alpha_shape is (2, 3), so shape is (3, 1, 3)
|
>>> # case 2: alpha_shape is (2), so shape is (7, 5, 2)
|
||||||
>>> shape = (3, 1, 3)
|
>>> shape = (7, 5)
|
||||||
>>> alpha = Tensor(np.array([[1, 3, 4], [2, 5, 6]]), mindspore.float32)
|
>>> alpha = Tensor(np.array([0.5, 1.5]), mindspore.float32)
|
||||||
>>> beta = Tensor(np.array([1.0]), mindspore.float32)
|
>>> beta = Tensor(np.array([1.0]), mindspore.float32)
|
||||||
>>> output = ops.gamma(shape, alpha, beta, seed=5)
|
>>> output = ops.gamma(shape, alpha, beta, seed=5)
|
||||||
>>> result = output.shape
|
>>> result = output.shape
|
||||||
>>> print(result)
|
>>> print(result)
|
||||||
(3, 2, 3)
|
(7, 5, 2)
|
||||||
>>> # case 3: beta_shape is (1, 2), the output is different.
|
>>> # case 3: beta_shape is (1, 2), the output is different.
|
||||||
>>> shape = (3, 1, 2)
|
>>> shape = (3, 1, 2)
|
||||||
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
|
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
|
||||||
>>> beta = Tensor(np.array([1.0, 2]), mindspore.float32)
|
>>> beta = Tensor(np.array([3.0, 2.0]), mindspore.float32)
|
||||||
>>> output = ops.gamma(shape, alpha, beta, seed=5)
|
>>> output = ops.gamma(shape, alpha, beta, seed=3)
|
||||||
>>> result = output.shape
|
>>> result = output.shape
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
[[[ 2.2132034 5.8855834]]
|
[[[[[0.8373873 1.4698703 ]
|
||||||
[ 3.3981476 7.5805717]
|
[1.0850314 3.487788 ]]
|
||||||
[[ 3.3981476 7.5805717]]
|
[[0.57389003 1.8903136 ]
|
||||||
[ 3.7190282 19.941492]
|
[1.2278512 1.3656161 ]]]]
|
||||||
[[ 2.9512358 2.5969937]]
|
[[[[0.12379696 1.9381095 ]
|
||||||
[ 3.786061 5.160872 ]]]
|
[1.3704795 3.5111923 ]]
|
||||||
>>> # case 4: beta_shape is (2, 1), the output is different.
|
[[0.49400368 1.9125801 ]
|
||||||
>>> shape = (3, 1, 2)
|
[0.94508415 2.0883005 ]]]]
|
||||||
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
|
[[[[0.5898374 1.1703413 ]
|
||||||
>>> beta = Tensor(np.array([[1.0], [2.0]]), mindspore.float32)
|
[1.4078385 0.8582265 ]]
|
||||||
>>> output = ops.gamma(shape, alpha, beta, seed=5)
|
[[0.5685522 1.4178807 ]
|
||||||
>>> result = output.shape
|
[1.5442697 3.6673684 ]]]]]
|
||||||
>>> print(output)
|
|
||||||
[[[ 5.6085486 7.8280783]]
|
|
||||||
[ 15.97684 16.116285]
|
|
||||||
[[ 1.8347423 1.713663]]
|
|
||||||
[ 3.2434065 15.667398]
|
|
||||||
[[ 4.2922077 7.3365674]]
|
|
||||||
[ 5.3876944 13.159832 ]]]
|
|
||||||
"""
|
"""
|
||||||
seed1, seed2 = _get_seed(seed, "gamma")
|
seed1, seed2 = _get_seed(seed, "gamma")
|
||||||
random_gamma = P.Gamma(seed1, seed2)
|
random_gamma = P.Gamma(seed1, seed2)
|
||||||
value = random_gamma(shape, alpha, beta)
|
alpha_type = F.dtype(alpha)
|
||||||
|
if beta is None:
|
||||||
|
beta = Tensor(np.array([1.0]), alpha_type)
|
||||||
|
value = random_gamma(shape, alpha, beta) / beta
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -288,7 +288,7 @@ class Gamma(PrimitiveWithInfer):
|
||||||
ValueError: If `shape` is not a constant value.
|
ValueError: If `shape` is not a constant value.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend``
|
``Ascend`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> shape = (3, 1, 2)
|
>>> shape = (3, 1, 2)
|
||||||
|
@ -298,7 +298,7 @@ class Gamma(PrimitiveWithInfer):
|
||||||
>>> output = gamma(shape, alpha, beta)
|
>>> output = gamma(shape, alpha, beta)
|
||||||
>>> result = output.shape
|
>>> result = output.shape
|
||||||
>>> print(result)
|
>>> print(result)
|
||||||
(3, 2, 2)
|
(3, 1, 2, 2, 2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
@ -320,10 +320,11 @@ class Gamma(PrimitiveWithInfer):
|
||||||
Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
|
Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
|
||||||
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name,
|
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name,
|
||||||
arg_name1="alpha", arg_name2="beta")
|
arg_name1="alpha", arg_name2="beta")
|
||||||
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name,
|
out_shape = list(shape_v)
|
||||||
arg_name1="broadcast_alpha_beta", arg_name2="shape")
|
out_shape.extend(broadcast_shape)
|
||||||
|
|
||||||
out = {
|
out = {
|
||||||
'shape': broadcast_shape,
|
'shape': out_shape,
|
||||||
'dtype': mstype.float32,
|
'dtype': mstype.float32,
|
||||||
'value': None}
|
'value': None}
|
||||||
return out
|
return out
|
||||||
|
@ -447,7 +448,6 @@ class RandomPoisson(Primitive):
|
||||||
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class UniformInt(PrimitiveWithInfer):
|
class UniformInt(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is,
|
Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is,
|
||||||
|
|
|
@ -47,9 +47,9 @@ def test_net_1D():
|
||||||
def test_net_ND():
|
def test_net_ND():
|
||||||
seed = 10
|
seed = 10
|
||||||
shape = (3, 1, 2)
|
shape = (3, 1, 2)
|
||||||
alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
|
alpha = np.array([[1, 2], [3, 4]]).astype(np.float32)
|
||||||
beta = np.array([1.0]).astype(np.float32)
|
beta = np.array([1.0]).astype(np.float32)
|
||||||
net = Net(shape=shape, seed=seed)
|
net = Net(shape=shape, seed=seed)
|
||||||
talpha, tbeta = Tensor(alpha), Tensor(beta)
|
talpha, tbeta = Tensor(alpha), Tensor(beta)
|
||||||
output = net(talpha, tbeta)
|
output = net(talpha, tbeta)
|
||||||
assert output.shape == (3, 2, 2)
|
assert output.shape == (3, 1, 2, 2, 2)
|
||||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.common import set_seed
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
set_seed(20)
|
set_seed(20)
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self, shape, seed=0):
|
def __init__(self, shape, seed=0):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -48,9 +49,9 @@ def test_net_1D():
|
||||||
def test_net_ND():
|
def test_net_ND():
|
||||||
seed = 10
|
seed = 10
|
||||||
shape = (3, 1, 2)
|
shape = (3, 1, 2)
|
||||||
alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
|
alpha = np.array([[1, 2], [3, 4]]).astype(np.float32)
|
||||||
beta = np.array([1.0]).astype(np.float32)
|
beta = np.array([1.0]).astype(np.float32)
|
||||||
net = Net(shape, seed)
|
net = Net(shape, seed)
|
||||||
talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
|
talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
|
||||||
output = net(talpha, tbeta)
|
output = net(talpha, tbeta)
|
||||||
assert output.shape == (3, 2, 2)
|
assert output.shape == (3, 1, 2, 2, 2)
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RandomGammaTEST(nn.Cell):
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
super(RandomGammaTEST, self).__init__()
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def construct(self, shape, alpha, beta):
|
||||||
|
return C.gamma(shape, alpha, beta, self.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
|
||||||
|
def test_gamma_op(dtype):
|
||||||
|
"""
|
||||||
|
Feature: Gamma cpu kernel
|
||||||
|
Description: test the gamma beta is a tensor.
|
||||||
|
Expectation: match to tensorflow benchmark.
|
||||||
|
"""
|
||||||
|
|
||||||
|
shape = (3, 1, 2)
|
||||||
|
alpha = Tensor(np.array([[3, 4], [5, 6]]), ms.float32)
|
||||||
|
beta = Tensor(np.array([3.0, 2.0]), ms.float32)
|
||||||
|
gamma_test = RandomGammaTEST(seed=3)
|
||||||
|
expect = np.array([3, 1, 2, 2, 2])
|
||||||
|
|
||||||
|
ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
|
||||||
|
output = gamma_test(shape, alpha, beta)
|
||||||
|
assert (output.shape == expect).all()
|
||||||
|
|
||||||
|
ms.set_context(mode=ms.PYNATIVE_MODE)
|
||||||
|
output = ms.ops.gamma(shape, alpha, beta)
|
||||||
|
assert (output.shape == expect).all()
|
||||||
|
|
||||||
|
ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
|
||||||
|
output = gamma_test(shape, alpha, None)
|
||||||
|
assert (output.shape == expect).all()
|
||||||
|
|
||||||
|
ms.set_context(mode=ms.PYNATIVE_MODE)
|
||||||
|
output = ms.ops.gamma(shape, alpha, None)
|
||||||
|
assert (output.shape == expect).all()
|
|
@ -1,94 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
|
||||||
from mindspore import Tensor, context
|
|
||||||
from mindspore.nn import Cell
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
|
|
||||||
from parallel.utils.utils import ParallelValidator, compile_net
|
|
||||||
|
|
||||||
SEED_ = 1
|
|
||||||
SEED2_ = 1
|
|
||||||
alpha_ = Tensor(np.array([1.0]), mstype.float32)
|
|
||||||
beta_ = Tensor(np.array([1.0]), mstype.float32)
|
|
||||||
|
|
||||||
|
|
||||||
class Net(Cell):
|
|
||||||
def __init__(self, seed, seed2, strategy=None):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.gamma = P.Gamma(seed, seed2).shard(strategy)
|
|
||||||
|
|
||||||
def construct(self, shape, alpha, beta):
|
|
||||||
out = self.gamma(shape, alpha, beta)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def test_gamma_auto_parallel():
|
|
||||||
"""
|
|
||||||
Features: test Gamma auto parallel
|
|
||||||
Description: auto parallel
|
|
||||||
Expectation: compile success
|
|
||||||
"""
|
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
|
|
||||||
net = Net(SEED_, SEED2_)
|
|
||||||
shape = (4, 4, 4)
|
|
||||||
compile_net(net, shape, alpha_, beta_)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gamma_data_parallel():
|
|
||||||
"""
|
|
||||||
Features: test Gamma data parallel
|
|
||||||
Description: data parallel
|
|
||||||
Expectation: compile success
|
|
||||||
"""
|
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
|
|
||||||
net = Net(SEED_, SEED2_)
|
|
||||||
shape = (8, 8)
|
|
||||||
phase = compile_net(net, shape, alpha_, beta_)
|
|
||||||
|
|
||||||
validator = ParallelValidator(net, phase)
|
|
||||||
assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2})
|
|
||||||
|
|
||||||
|
|
||||||
def test_gamma_model_parallel():
|
|
||||||
"""
|
|
||||||
Features: test Gamma model parallel
|
|
||||||
Description: model parallel
|
|
||||||
Expectation: compile success
|
|
||||||
"""
|
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
|
|
||||||
shape = (8, 8)
|
|
||||||
strategy = ((2, 2), (1,), (1,))
|
|
||||||
net = Net(SEED_, SEED2_, strategy)
|
|
||||||
phase = compile_net(net, shape, alpha_, beta_)
|
|
||||||
validator = ParallelValidator(net, phase)
|
|
||||||
assert validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3})
|
|
||||||
|
|
||||||
|
|
||||||
def test_gamma_strategy_error():
|
|
||||||
"""
|
|
||||||
Features:test Gamma strategy error
|
|
||||||
Description: invalid strategy
|
|
||||||
Expectation: Raise RuntimeError
|
|
||||||
"""
|
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
|
||||||
shape = (8, 8)
|
|
||||||
strategy = ((2, 2), (2,), (1,))
|
|
||||||
net = Net(SEED_, SEED2_, strategy)
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
compile_net(net, shape, alpha_, beta_)
|
|
Loading…
Reference in New Issue