!45707 fix gamma ops dynamic rank bug
Merge pull request !45707 from zhangqi/1118
commit 98977f0366
@@ -14,7 +14,6 @@
 # ============================================================================
 """Defines parameter operators with functional form."""

-import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.primitive import constexpr
@@ -24,7 +23,6 @@ from mindspore.common.seed import _get_graph_seed
 from mindspore.common.tensor import Tensor
 from mindspore.ops.operations.random_ops import RandomShuffle, RandomChoiceWithMask
 from mindspore.ops._primitive_cache import _get_cache_prim
-from mindspore.ops._utils import get_broadcast_shape


 def random_gamma(shape, alpha, seed=0, seed2=0):
@@ -67,16 +65,8 @@ def random_gamma(shape, alpha, seed=0, seed2=0):
         (7, 5, 2)
     """

-    alpha_type = P.DType()(alpha)
-    beta = Tensor(np.array([1.0]), alpha_type)
-    alpha_shape = P.Shape()(alpha)
-    beta_shape = P.Shape()(beta)
-    broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma", arg_name1="alpha", arg_name2="beta")
-    broadcast_shape_t = tuple(broadcast_shape)
-    broadcast_to = P.BroadcastTo(broadcast_shape_t)
-    alpha_broadcast = broadcast_to(alpha)
     random_gamma_op = _get_cache_prim(P.RandomGamma)(seed=seed, seed2=seed2)
-    output = random_gamma_op(shape, alpha_broadcast)
+    output = random_gamma_op(shape, alpha)

     return output
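For context, a minimal usage sketch of the fixed functional path (illustrative, not part of this diff): with the BroadcastTo against a constant beta removed, alpha is forwarded to the RandomGamma primitive unchanged, so the call no longer depends on alpha's rank being known at compile time.

# Usage sketch (illustrative, not part of this commit). Values mirror the
# docstring example above: the output shape is the requested shape followed
# by alpha's shape.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import functional as F

shape = Tensor(np.array([7, 5]), ms.int32)
alpha = Tensor(np.array([0.5, 1.5]), ms.float32)
out = F.random_gamma(shape, alpha, seed=5)
print(out.shape)  # expected: (7, 5, 2)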
@@ -20,6 +20,7 @@ import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
+import mindspore.numpy as ms_np


 class RandomGammaTEST(nn.Cell):
@@ -38,6 +39,42 @@ class RandomGamma(nn.Cell):
         return F.random_gamma(shape, alpha, seed)


+class RandomGammaDR(nn.Cell):
+    def __init__(self):
+        super(RandomGammaDR, self).__init__()
+        self.reducesum = P.ReduceSum(keep_dims=False)
+
+    def construct(self, shape, alpha, seed=0):
+        axis = ms_np.randint(1, 2, (2,))
+        rand_axis = ms_np.unique(axis)
+        outshape = self.reducesum(shape, rand_axis)
+        outalpha = self.reducesum(alpha, rand_axis)
+        return F.random_gamma(outshape, outalpha, seed), outshape, outalpha
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_x86_cpu
+def test_dynamic_rank():
+    """
+    Feature: RandomGamma cpu kernel for dynamic rank
+    Description: test the random gamma alpha is a tensor.
+    Expectation: match to tensorflow benchmark.
+    """
+
+    ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
+    shape_ = Tensor(np.random.randint(low=1, high=8, size=(2, 2)), ms.int32)
+    alpha_ = Tensor(np.random.randint(low=1, high=5, size=(2, 2)), ms.float32)
+    net = RandomGammaDR()
+    input_dyn_shape = [None, None]
+    alpha_dyn_shape = [None, None]
+    net.set_inputs(Tensor(shape=input_dyn_shape, dtype=ms.int32),
+                   Tensor(shape=alpha_dyn_shape, dtype=ms.float32))
+    output, out_s, out_a = net(shape_, alpha_)
+    expect = np.concatenate((out_s, np.shape(out_a)), axis=0)
+    assert (output.shape == expect).all()
+
+
 @pytest.mark.level0
 @pytest.mark.env_onecard
 @pytest.mark.platform_x86_cpu
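A note on the new test: the reduction axes come from ms_np.unique of a runtime tensor, so the rank of the reduced shape and alpha tensors is not fixed at compile time, and set_inputs with None dimensions marks the network inputs as dynamically shaped. The assertion then checks the shape contract of random_gamma, sketched below with illustrative values.

# Shape-contract sketch (illustrative values, not part of this commit):
# random_gamma's output shape is the entries of the `shape` tensor followed
# by alpha's own shape, which is what the test's expect/assert compares.
import numpy as np

out_s = np.array([3, 4])   # stand-in for the ReduceSum-ed `shape` tensor values
out_a_shape = (2,)         # stand-in for the ReduceSum-ed alpha's shape
expect = np.concatenate((out_s, out_a_shape), axis=0)
print(expect)              # [3 4 2], the expected output.shape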