!45707 fix gamma ops dynamic rank bug

Merge pull request !45707 from zhangqi/1118
This commit is contained in:
i-robot 2022-11-25 08:25:03 +00:00 committed by Gitee
commit 98977f0366
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 38 additions and 11 deletions

View File

@ -14,7 +14,6 @@
# ============================================================================
"""Defines parameter operators with functional form."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
@ -24,7 +23,6 @@ from mindspore.common.seed import _get_graph_seed
from mindspore.common.tensor import Tensor
from mindspore.ops.operations.random_ops import RandomShuffle, RandomChoiceWithMask
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops._utils import get_broadcast_shape
def random_gamma(shape, alpha, seed=0, seed2=0):
@ -67,16 +65,8 @@ def random_gamma(shape, alpha, seed=0, seed2=0):
(7, 5, 2)
"""
alpha_type = P.DType()(alpha)
beta = Tensor(np.array([1.0]), alpha_type)
alpha_shape = P.Shape()(alpha)
beta_shape = P.Shape()(beta)
broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma", arg_name1="alpha", arg_name2="beta")
broadcast_shape_t = tuple(broadcast_shape)
broadcast_to = P.BroadcastTo(broadcast_shape_t)
alpha_broadcast = broadcast_to(alpha)
random_gamma_op = _get_cache_prim(P.RandomGamma)(seed=seed, seed2=seed2)
output = random_gamma_op(shape, alpha_broadcast)
output = random_gamma_op(shape, alpha)
return output

View File

@ -20,6 +20,7 @@ import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
import mindspore.numpy as ms_np
class RandomGammaTEST(nn.Cell):
@ -38,6 +39,42 @@ class RandomGamma(nn.Cell):
return F.random_gamma(shape, alpha, seed)
class RandomGammaDR(nn.Cell):
def __init__(self):
super(RandomGammaDR, self).__init__()
self.reducesum = P.ReduceSum(keep_dims=False)
def construct(self, shape, alpha, seed=0):
axis = ms_np.randint(1, 2, (2,))
rand_axis = ms_np.unique(axis)
outshape = self.reducesum(shape, rand_axis)
outalpha = self.reducesum(alpha, rand_axis)
return F.random_gamma(outshape, outalpha, seed), outshape, outalpha
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_dynamic_rank():
"""
Feature: RandomGamma cpu kernel for dynamic rank
Description: test the random gamma alpha is a tensor.
Expectation: match to tensorflow benchmark.
"""
ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
shape_ = Tensor(np.random.randint(low=1, high=8, size=(2, 2)), ms.int32)
alpha_ = Tensor(np.random.randint(low=1, high=5, size=(2, 2)), ms.float32)
net = RandomGammaDR()
input_dyn_shape = [None, None]
alpha_dyn_shape = [None, None]
net.set_inputs(Tensor(shape=input_dyn_shape, dtype=ms.int32),
Tensor(shape=alpha_dyn_shape, dtype=ms.float32))
output, out_s, out_a = net(shape_, alpha_)
expect = np.concatenate((out_s, np.shape(out_a)), axis=0)
assert (output.shape == expect).all()
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu