add python implement and st for random gamma

This commit is contained in:
zhangqi 2022-05-31 09:33:32 +08:00
parent 45241fa3a6
commit d0afaf60f0
6 changed files with 101 additions and 132 deletions

View File

@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Operations for random number generators.""" """Operations for random number generators."""
import numpy as np
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from ...common.tensor import Tensor
from .. import operations as P from .. import operations as P
from .. import functional as F from .. import functional as F
from .multitype_ops import _constexpr_utils as const_utils from .multitype_ops import _constexpr_utils as const_utils
@ -234,7 +236,7 @@ def gamma(shape, alpha, beta, seed=None):
TypeError: If dtype of `alpha` and `beta` is not float32. TypeError: If dtype of `alpha` and `beta` is not float32.
Supported Platforms: Supported Platforms:
``Ascend`` ``Ascend`` ``CPU``
Examples: Examples:
>>> from mindspore import Tensor, ops >>> from mindspore import Tensor, ops
@ -246,45 +248,41 @@ def gamma(shape, alpha, beta, seed=None):
>>> output = ops.gamma(shape, alpha, beta, seed=5) >>> output = ops.gamma(shape, alpha, beta, seed=5)
>>> result = output.shape >>> result = output.shape
>>> print(result) >>> print(result)
(3, 2, 2) (3, 1, 2, 2, 2)
>>> # case 2: alpha_shape is (2, 3), so shape is (3, 1, 3) >>> # case 2: alpha_shape is (2), so shape is (7, 5, 2)
>>> shape = (3, 1, 3) >>> shape = (7, 5)
>>> alpha = Tensor(np.array([[1, 3, 4], [2, 5, 6]]), mindspore.float32) >>> alpha = Tensor(np.array([0.5, 1.5]), mindspore.float32)
>>> beta = Tensor(np.array([1.0]), mindspore.float32) >>> beta = Tensor(np.array([1.0]), mindspore.float32)
>>> output = ops.gamma(shape, alpha, beta, seed=5) >>> output = ops.gamma(shape, alpha, beta, seed=5)
>>> result = output.shape >>> result = output.shape
>>> print(result) >>> print(result)
(3, 2, 3) (7, 5, 2)
>>> # case 3: beta_shape is (1, 2), the output is different. >>> # case 3: beta_shape is (1, 2), the output is different.
>>> shape = (3, 1, 2) >>> shape = (3, 1, 2)
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32) >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
>>> beta = Tensor(np.array([1.0, 2]), mindspore.float32) >>> beta = Tensor(np.array([3.0, 2.0]), mindspore.float32)
>>> output = ops.gamma(shape, alpha, beta, seed=5) >>> output = ops.gamma(shape, alpha, beta, seed=3)
>>> result = output.shape >>> result = output.shape
>>> print(output) >>> print(output)
[[[ 2.2132034 5.8855834]] [[[[[0.8373873 1.4698703 ]
[ 3.3981476 7.5805717] [1.0850314 3.487788 ]]
[[ 3.3981476 7.5805717]] [[0.57389003 1.8903136 ]
[ 3.7190282 19.941492] [1.2278512 1.3656161 ]]]]
[[ 2.9512358 2.5969937]] [[[[0.12379696 1.9381095 ]
[ 3.786061 5.160872 ]]] [1.3704795 3.5111923 ]]
>>> # case 4: beta_shape is (2, 1), the output is different. [[0.49400368 1.9125801 ]
>>> shape = (3, 1, 2) [0.94508415 2.0883005 ]]]]
>>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32) [[[[0.5898374 1.1703413 ]
>>> beta = Tensor(np.array([[1.0], [2.0]]), mindspore.float32) [1.4078385 0.8582265 ]]
>>> output = ops.gamma(shape, alpha, beta, seed=5) [[0.5685522 1.4178807 ]
>>> result = output.shape [1.5442697 3.6673684 ]]]]]
>>> print(output)
[[[ 5.6085486 7.8280783]]
[ 15.97684 16.116285]
[[ 1.8347423 1.713663]]
[ 3.2434065 15.667398]
[[ 4.2922077 7.3365674]]
[ 5.3876944 13.159832 ]]]
""" """
seed1, seed2 = _get_seed(seed, "gamma") seed1, seed2 = _get_seed(seed, "gamma")
random_gamma = P.Gamma(seed1, seed2) random_gamma = P.Gamma(seed1, seed2)
value = random_gamma(shape, alpha, beta) alpha_type = F.dtype(alpha)
if beta is None:
beta = Tensor(np.array([1.0]), alpha_type)
value = random_gamma(shape, alpha, beta) / beta
return value return value

View File

@ -288,7 +288,7 @@ class Gamma(PrimitiveWithInfer):
ValueError: If `shape` is not a constant value. ValueError: If `shape` is not a constant value.
Supported Platforms: Supported Platforms:
``Ascend`` ``Ascend`` ``CPU``
Examples: Examples:
>>> shape = (3, 1, 2) >>> shape = (3, 1, 2)
@ -298,7 +298,7 @@ class Gamma(PrimitiveWithInfer):
>>> output = gamma(shape, alpha, beta) >>> output = gamma(shape, alpha, beta)
>>> result = output.shape >>> result = output.shape
>>> print(result) >>> print(result)
(3, 2, 2) (3, 1, 2, 2, 2)
""" """
@prim_attr_register @prim_attr_register
@ -320,10 +320,11 @@ class Gamma(PrimitiveWithInfer):
Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name) Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name)
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name, broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name,
arg_name1="alpha", arg_name2="beta") arg_name1="alpha", arg_name2="beta")
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name, out_shape = list(shape_v)
arg_name1="broadcast_alpha_beta", arg_name2="shape") out_shape.extend(broadcast_shape)
out = { out = {
'shape': broadcast_shape, 'shape': out_shape,
'dtype': mstype.float32, 'dtype': mstype.float32,
'value': None} 'value': None}
return out return out
@ -447,7 +448,6 @@ class RandomPoisson(Primitive):
Validator.check_type_name("dtype", dtype, valid_values, self.name) Validator.check_type_name("dtype", dtype, valid_values, self.name)
class UniformInt(PrimitiveWithInfer): class UniformInt(PrimitiveWithInfer):
r""" r"""
Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is, Produces random integer values i, uniformly distributed on the closed interval [minval, maxval), that is,

View File

@ -47,9 +47,9 @@ def test_net_1D():
def test_net_ND(): def test_net_ND():
seed = 10 seed = 10
shape = (3, 1, 2) shape = (3, 1, 2)
alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) alpha = np.array([[1, 2], [3, 4]]).astype(np.float32)
beta = np.array([1.0]).astype(np.float32) beta = np.array([1.0]).astype(np.float32)
net = Net(shape=shape, seed=seed) net = Net(shape=shape, seed=seed)
talpha, tbeta = Tensor(alpha), Tensor(beta) talpha, tbeta = Tensor(alpha), Tensor(beta)
output = net(talpha, tbeta) output = net(talpha, tbeta)
assert output.shape == (3, 2, 2) assert output.shape == (3, 1, 2, 2, 2)

View File

@ -24,6 +24,7 @@ from mindspore.common import set_seed
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
set_seed(20) set_seed(20)
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, shape, seed=0): def __init__(self, shape, seed=0):
super(Net, self).__init__() super(Net, self).__init__()
@ -48,9 +49,9 @@ def test_net_1D():
def test_net_ND(): def test_net_ND():
seed = 10 seed = 10
shape = (3, 1, 2) shape = (3, 1, 2)
alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) alpha = np.array([[1, 2], [3, 4]]).astype(np.float32)
beta = np.array([1.0]).astype(np.float32) beta = np.array([1.0]).astype(np.float32)
net = Net(shape, seed) net = Net(shape, seed)
talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32) talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
output = net(talpha, tbeta) output = net(talpha, tbeta)
assert output.shape == (3, 2, 2) assert output.shape == (3, 1, 2, 2, 2)

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore import Tensor
class RandomGammaTEST(nn.Cell):
def __init__(self, seed=0):
super(RandomGammaTEST, self).__init__()
self.seed = seed
def construct(self, shape, alpha, beta):
return C.gamma(shape, alpha, beta, self.seed)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
def test_gamma_op(dtype):
"""
Feature: Gamma cpu kernel
Description: test the gamma beta is a tensor.
Expectation: match to tensorflow benchmark.
"""
shape = (3, 1, 2)
alpha = Tensor(np.array([[3, 4], [5, 6]]), ms.float32)
beta = Tensor(np.array([3.0, 2.0]), ms.float32)
gamma_test = RandomGammaTEST(seed=3)
expect = np.array([3, 1, 2, 2, 2])
ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
output = gamma_test(shape, alpha, beta)
assert (output.shape == expect).all()
ms.set_context(mode=ms.PYNATIVE_MODE)
output = ms.ops.gamma(shape, alpha, beta)
assert (output.shape == expect).all()
ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
output = gamma_test(shape, alpha, None)
assert (output.shape == expect).all()
ms.set_context(mode=ms.PYNATIVE_MODE)
output = ms.ops.gamma(shape, alpha, None)
assert (output.shape == expect).all()

View File

@ -1,94 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net
SEED_ = 1
SEED2_ = 1
alpha_ = Tensor(np.array([1.0]), mstype.float32)
beta_ = Tensor(np.array([1.0]), mstype.float32)
class Net(Cell):
def __init__(self, seed, seed2, strategy=None):
super(Net, self).__init__()
self.gamma = P.Gamma(seed, seed2).shard(strategy)
def construct(self, shape, alpha, beta):
out = self.gamma(shape, alpha, beta)
return out
def test_gamma_auto_parallel():
"""
Features: test Gamma auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
net = Net(SEED_, SEED2_)
shape = (4, 4, 4)
compile_net(net, shape, alpha_, beta_)
def test_gamma_data_parallel():
"""
Features: test Gamma data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
net = Net(SEED_, SEED2_)
shape = (8, 8)
phase = compile_net(net, shape, alpha_, beta_)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2})
def test_gamma_model_parallel():
"""
Features: test Gamma model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
shape = (8, 8)
strategy = ((2, 2), (1,), (1,))
net = Net(SEED_, SEED2_, strategy)
phase = compile_net(net, shape, alpha_, beta_)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3})
def test_gamma_strategy_error():
"""
Features:test Gamma strategy error
Description: invalid strategy
Expectation: Raise RuntimeError
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
shape = (8, 8)
strategy = ((2, 2), (2,), (1,))
net = Net(SEED_, SEED2_, strategy)
with pytest.raises(RuntimeError):
compile_net(net, shape, alpha_, beta_)