Add dynamic shape support and a functional API for random_gamma
parent 389dc9ac27
commit ae6c8c807d
@@ -313,6 +313,7 @@ Tensor Creation
     :template: classtemplate.rst

     mindspore.ops.gamma
+    mindspore.ops.random_gamma
     mindspore.ops.multinomial
     mindspore.ops.poisson
     mindspore.ops.random_categorical
@@ -0,0 +1,23 @@
+mindspore.ops.random_gamma
+==========================
+
+.. py:function:: mindspore.ops.random_gamma(shape, alpha, seed=None)
+
+    Generates random numbers from the Gamma distribution(s).
+
+    **Parameters:**
+
+    - **shape** (Tensor) - The shape of the random tensor to be generated: a 1-D integer Tensor.
+    - **alpha** (Tensor) - The :math:`\alpha` distribution parameter. Must be greater than 0, with dtype half, float32 or float64.
+    - **seed** (int) - Seed for the random number generator; must be non-negative. Default: None, which is treated as 0.
+
+    **Returns:**
+
+    Tensor. Its shape is the concatenation of the input `shape` and the shape of `alpha`; its dtype is the same as `alpha`.
+
+    **Raises:**
+
+    - **TypeError** - If `shape` is not a Tensor.
+    - **TypeError** - If `alpha` is not a Tensor.
+    - **TypeError** - If the dtype of `seed` is not int.
+    - **TypeError** - If the dtype of `alpha` is not half, float32 or float64.
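For quick reference, a usage sketch of the documented function, adapted from the examples elsewhere in this commit (it assumes a backend with RandomGamma support, e.g. CPU):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    # The output shape is the requested sample shape followed by alpha's shape.
    shape = Tensor(np.array([3, 1, 2]), mindspore.int32)
    alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
    output = ops.random_gamma(shape, alpha, seed=5)
    print(output.shape)  # (3, 1, 2, 2, 2)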
@@ -312,6 +312,7 @@ Randomly Generating Operators
     :template: classtemplate.rst

     mindspore.ops.gamma
+    mindspore.ops.random_gamma
     mindspore.ops.multinomial
     mindspore.ops.poisson
     mindspore.ops.random_categorical
@@ -25,6 +25,8 @@
 namespace mindspore {
 namespace kernel {
+static constexpr size_t INPUT_NUM = 2;
+static constexpr size_t OUTPUT_NUM = 1;
 bool GammaCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
   auto kernel_ptr = std::dynamic_pointer_cast<ops::RandomGamma>(base_operator);
@@ -37,9 +39,13 @@ bool GammaCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
   seed2_ = kernel_ptr->get_seed2();
   generator_.Init(seed_, seed2_);

+  outputs_ = outputs;
   output_shape_ = outputs[0]->GetShapeVector();
   alpha_shape_ = inputs[1]->GetShapeVector();
   alpha_dtype_ = inputs[1]->GetDtype();
+  shape_dtype_ = inputs[0]->GetDtype();
+
+  is_need_retrieve_output_shape_ = true;

   return true;
 }
@@ -47,19 +53,19 @@ bool GammaCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
 int GammaCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs,
                               const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), INPUT_NUM, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), OUTPUT_NUM, kernel_name_);
   int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
-  if (ret != 0) {
+  if (ret != KRET_OK) {
+    dynamic_shape_ = ret == KRET_UNKNOWN_OUT_SHAPE;
     return ret;
   }
-  auto input_shape_shape = inputs[0]->GetShapeVector();
-  int shape_len = input_shape_shape.size();
-  for (int i = 0; i < shape_len; i++) {
-    if (input_shape_shape[i] != output_shape_[i]) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' output size and input size mismatch.";
-      return KRET_RESIZE_FAILED;
-    }
-  }
-  return ret;
+
+  shape_shape_ = inputs[0]->GetShapeVector();
+  alpha_shape_ = inputs[1]->GetShapeVector();
+
+  return KRET_OK;
 }
@@ -168,8 +174,35 @@ void GammaCpuKernelMod::Generate(const std::vector<AddressPtr> &inputs, const st
   ParallelLaunchAutoSearch(DoWork, num_alphas * num_samples, this, &parallel_search_info_);
 }

+template <typename T>
+void GammaCpuKernelMod::InferShape(const std::vector<AddressPtr> &inputs) {
+  const auto *shape_value = reinterpret_cast<T *>(inputs[0]->addr);
+
+  for (int64_t i = 0; i < shape_shape_[0]; i++) {
+    output_shape_.emplace_back(static_cast<int64_t>(shape_value[i]));
+  }
+  for (size_t i = 0; i < alpha_shape_.size(); i++) {
+    output_shape_.emplace_back(alpha_shape_[i]);
+  }
+}
+
 bool GammaCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                const std::vector<AddressPtr> &outputs) {
+  if (dynamic_shape_) {
+    output_shape_.clear();
+    if (output_shape_.empty()) {
+      if (shape_dtype_ == kNumberTypeInt32) {
+        InferShape<int32_t>(inputs);
+      } else if (shape_dtype_ == kNumberTypeInt64) {
+        InferShape<int64_t>(inputs);
+      }
+      outputs_[0]->SetShapeVector(output_shape_);
+    } else {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' output size and input size mismatch.";
+      return false;
+    }
+  }
+
   if (alpha_dtype_ == kNumberTypeFloat16) {
     Generate<float16>(inputs, outputs);
   } else if (alpha_dtype_ == kNumberTypeFloat32) {
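The InferShape helper above simply reads the runtime values of the `shape` input and appends the static dimensions of `alpha`. A minimal Python sketch of that rule (the function name here is illustrative, not a MindSpore API):

    def infer_gamma_output_shape(shape_values, alpha_shape):
        # Sample dimensions come first, then one axis per alpha dimension.
        return [int(v) for v in shape_values] + list(alpha_shape)

    # [3, 1, 2] samples of a (2, 2) alpha grid -> output rank 5
    assert infer_gamma_output_shape([3, 1, 2], (2, 2)) == [3, 1, 2, 2, 2]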
@@ -46,20 +46,26 @@ class GammaCpuKernelMod : public NativeCpuKernelMod {
             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

   std::vector<KernelAttr> GetOpSupport() override;
+  std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }

 private:
   template <typename T>
   void Generate(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  template <typename T>
+  void InferShape(const std::vector<AddressPtr> &inputs);
   size_t seed_{0};
   size_t seed2_{0};

-  std::vector<int64_t> output_shape_;
-  std::vector<int64_t> shape_shape_;
-  std::vector<int64_t> alpha_shape_;
+  ShapeVector output_shape_;
+  ShapeVector shape_shape_;
+  ShapeVector alpha_shape_;
+  TypeId shape_dtype_{kTypeUnknown};
   TypeId alpha_dtype_{kTypeUnknown};

   GuardedPhiloxRandom generator_;
+  // Dealing with dynamic shapes
+  bool dynamic_shape_{false};
+  std::vector<KernelTensorPtr> outputs_{};
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -51,7 +51,14 @@ abstract::ShapePtr GammaInferShape(const PrimitivePtr &primitive, const std::vec
   auto input_shape_value_ptr = input_shape->BuildValue();
   MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
   auto shape_value_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
-  MS_EXCEPTION_IF_NULL(shape_value_tensor);
+  // MS_EXCEPTION_IF_NULL(shape_value_tensor) is skipped here: a null value
+  // tensor means the shape is unknown at compile time (dynamic shape).
+  if (shape_value_tensor == nullptr) {
+    ShapeVector out_shape = {-2};
+    ShapeVector infer_shape_min;
+    ShapeVector infer_shape_max;
+    infer_shape_min = infer_shape_max = {1};
+    return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
+  }

   auto shape_type_element = input_args[kInputIndex0]->BuildType()->cast<TensorTypePtr>()->element();
   MS_EXCEPTION_IF_NULL(shape_type_element);
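When the value of the `shape` tensor is not yet known at graph-compile time, the cast above yields nullptr and the infer function now returns the sentinel shape {-2}, which in MindSpore's shape convention marks a tensor of unknown rank (a -1 entry would mark a single unknown dimension). A hedged Python sketch of the same fallback logic, illustrative only:

    UNKNOWN_RANK = [-2]  # sentinel: even the rank is unknown at compile time

    def gamma_infer_shape(shape_values, alpha_shape):
        # Mirror of GammaInferShape's control flow, not a MindSpore API.
        if shape_values is None:   # shape value not yet available
            return UNKNOWN_RANK    # defer real inference to kernel launch
        return list(shape_values) + list(alpha_shape)

    assert gamma_infer_shape(None, (2, 2)) == [-2]
    assert gamma_infer_shape([7, 5], (2,)) == [7, 5, 2]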
@@ -99,3 +99,5 @@ get_unsupported_dynamic_vmap_rule = \
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformInt)(get_unsupported_dynamic_vmap_rule)
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformReal)(get_unsupported_dynamic_vmap_rule)
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.StandardNormal)(get_unsupported_dynamic_vmap_rule)
+get_unsupported_dynamic_vmap_rule = \
+    vmap_rules_getters.register(P.RandomGamma)(get_unsupported_dynamic_vmap_rule)
@@ -26,7 +26,7 @@ from .clip_ops import clip_by_value, clip_by_global_norm
 from .multitype_ops.add_impl import hyper_add
 from .multitype_ops.ones_like_impl import ones_like
 from .multitype_ops.zeros_like_impl import zeros_like
-from .random_ops import normal, laplace, uniform, random_gamma, gamma, poisson, multinomial
+from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
 from .math_ops import count_nonzero, tensor_dot, dot, batch_dot, matmul, cummin
 from .array_ops import repeat_elements, sequence_mask
 from .vmap_ops import _VmapGeneralPreprocess, _VmapGeneralRule
@@ -49,7 +49,6 @@ __all__ = [
     'laplace',
     'uniform',
     'gamma',
-    'random_gamma',
     'poisson',
     'multinomial',
     'clip_by_value',
@@ -276,57 +276,6 @@ def gamma(shape, alpha, beta, seed=None):
     return value


-@_function_forbid_reuse
-def random_gamma(shape, alpha, seed=None):
-    """
-    Outputs random values from the Gamma distribution(s) described by alpha.
-
-    Args:
-        shape (Tensor): The shape of random tensor to be generated.
-            Must be one of the following types: int32, int64. 1-D integer tensor.
-        alpha (Tensor): The alpha α distribution parameter.
-            A Tensor. Must be one of the following types: half, float32, float64.
-        seed (int): Seed is used as entropy source for the random number engines to generate
-            pseudo-random numbers, must be non-negative. Default: None, which will be treated as 0.
-
-    Returns:
-        Tensor. The shape should be equal to the concat shape between the input `shape` and the broadcast
-        of `alpha`. The dtype is the same type as alpha.
-
-    Raises:
-        TypeError: If `shape` is not a Tensor.
-        TypeError: If `alpha` is not a Tensor.
-        TypeError: If `seed` is not an int.
-        TypeError: If dtype of `alpha` is not half, float32 or float64.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU``
-
-    Examples:
-        >>> from mindspore import Tensor, ops
-        >>> import mindspore
-        >>> # case 1: alpha_shape is (2, 2)
-        >>> shape = Tensor(np.array([3, 1, 2]), mindspore.int32)
-        >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
-        >>> output = ops.random_gamma(shape, alpha, seed=5)
-        >>> result = output.shape
-        >>> print(result)
-        (3, 1, 2, 2, 2)
-        >>> # case 2: alpha_shape is (2), so shape is (7, 5, 2)
-        >>> shape = Tensor(np.array([7, 5]), mindspore.int32)
-        >>> alpha = Tensor(np.array([0.5, 1.5]), mindspore.float32)
-        >>> output = ops.random_gamma(shape, alpha, seed=5)
-        >>> result = output.shape
-        >>> print(result)
-        (7, 5, 2)
-    """
-    seed1, seed2 = _get_seed(seed, "random_gamma")
-    random_gamma_v = P.RandomGamma(seed1, seed2)
-    value = random_gamma_v(shape, alpha)
-    return value
-
-
 @_function_forbid_reuse
 def poisson(shape, mean, seed=None):
     r"""
@@ -285,6 +285,7 @@ from .random_func import (
     random_categorical,
     uniform,
     standard_normal,
+    random_gamma,
 )

 __all__ = []
@@ -15,13 +15,72 @@

 """Defines parameter operators with functional form."""

-from mindspore.ops.primitive import constexpr
+import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops.primitive import constexpr
 from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from ...common import dtype as mstype
 from ...common.seed import _get_graph_seed
+from ...common.tensor import Tensor
 from .._primitive_cache import _get_cache_prim
+from .._utils import get_broadcast_shape
+
+
+def random_gamma(shape, alpha, beta=None, seed=0, seed2=0):
+    r"""
+    Outputs random values from the Gamma distribution(s) described by `alpha`.
+    It is defined as:
+
+    .. math::
+        \text{P}(x|\alpha,\beta) = \frac{\exp(-x/\beta)}{\beta^{\alpha} \cdot \Gamma(\alpha)} \cdot x^{\alpha-1}
+
+    Args:
+        shape (Tensor): The shape of random tensor to be generated.
+            Must be one of the following types: int32, int64. 1-D integer tensor.
+        alpha (Tensor): The :math:`\alpha` distribution parameter.
+            Must be one of the following types: half, float32, float64.
+        beta (Tensor): The :math:`\beta` distribution parameter, broadcast
+            against `alpha`. Default: None, which is treated as 1.0.
+        seed (int): Seed is used as entropy source for the random number engines to generate
+            pseudo-random numbers, must be non-negative. Default: 0.
+        seed2 (int): A second seed for the random number engine, must be non-negative. Default: 0.
+
+    Returns:
+        Tensor. The shape is the concatenation of the input `shape` and the broadcast
+        shape of `alpha` and `beta`. The dtype is the same type as `alpha`.
+
+    Raises:
+        TypeError: If `shape` is not a Tensor.
+        TypeError: If `alpha` is not a Tensor.
+        TypeError: If `seed` is not an int.
+        TypeError: If dtype of `alpha` is not half, float32 or float64.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import functional as F
+        >>> shape = Tensor(np.array([7, 5]), mindspore.int32)
+        >>> alpha = Tensor(np.array([0.5, 1.5]), mindspore.float32)
+        >>> output = F.random_gamma(shape, alpha, seed=5)
+        >>> result = output.shape
+        >>> print(result)
+        (7, 5, 2)
+    """
+    alpha_type = P.DType()(alpha)
+    if beta is None:
+        beta = Tensor(np.array([1.0]), alpha_type)
+    alpha_shape = P.Shape()(alpha)
+    beta_shape = P.Shape()(beta)
+    broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma",
+                                          arg_name1="alpha", arg_name2="beta")
+    broadcast_shape_t = tuple(broadcast_shape)
+    broadcast_to = P.BroadcastTo(broadcast_shape_t)
+    alpha_broadcast = broadcast_to(alpha)
+    random_gamma_op = _get_cache_prim(P.RandomGamma)(seed=seed, seed2=seed2)
+    output = random_gamma_op(shape, alpha_broadcast)
+
+    return output
+
+
 @constexpr(reuse_result=False)
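The wrapper above first broadcasts `alpha` against `beta`, and the broadcast shape, not alpha's raw shape, is what gets appended to the requested sample shape. A small numpy sketch of that shape arithmetic (numpy stands in for get_broadcast_shape here, purely for illustration):

    import numpy as np

    alpha_shape = (2, 2)
    beta_shape = (2,)

    # Standard trailing-axis broadcast rule: dimensions must match or be 1.
    broadcast_shape = np.broadcast_shapes(alpha_shape, beta_shape)
    print(broadcast_shape)            # (2, 2)

    # random_gamma(shape=[4, 2], alpha, beta) then yields (4, 2) + (2, 2):
    print((4, 2) + broadcast_shape)   # (4, 2, 2, 2)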
@@ -219,5 +278,6 @@ __all__ = [
     'random_categorical',
     'uniform',
     'standard_normal',
+    'random_gamma'
 ]
 __all__.sort()
@@ -18,6 +18,7 @@ import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -34,7 +35,7 @@ class RandomGammaTEST(nn.Cell):
 @pytest.mark.level0
 @pytest.mark.env_onecard
 @pytest.mark.platform_x86_cpu
-@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
+@pytest.mark.parametrize("dtype", [ms.float64, ms.float32, ms.float16])
 def test_random_gamma_op(dtype):
     """
     Feature: RandomGamma cpu kernel
@@ -52,6 +53,38 @@ def test_random_gamma_op(dtype):
     print(output)
     assert (output.shape == expect).all()

+    gamma = P.RandomGamma(seed=3)
+    output_2 = gamma(shape, alpha)
+    print(output_2)
+    assert (output_2.shape == expect).all()
+

 if __name__ == '__main__':
     test_random_gamma_op(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_x86_cpu
+@pytest.mark.parametrize("dtype", [ms.float64, ms.float32, ms.float16])
+def test_random_gamma_functional(dtype):
+    """
+    Feature: Functional interface of the RandomGamma cpu kernel.
+    Description: Test random_gamma when alpha is a tensor.
+    Expectation: Output shape matches the TensorFlow benchmark.
+    """
+    ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU')
+    shape = Tensor(np.array([3, 2]), ms.int32)
+    alpha = Tensor(np.array([[3, 4], [5, 6]]), dtype)
+    output = F.random_gamma(shape=shape, alpha=alpha, seed=2)
+    expect = np.array([3, 2, 2, 2])
+
+    print(output)
+    assert (output.shape == expect).all()
+
+    ms.set_context(mode=ms.PYNATIVE_MODE, device_target='CPU')
+    shape = Tensor(np.array([4, 2]), ms.int32)
+    alpha = Tensor(np.array([[3, 4], [5, 6]]), ms.float32)
+    beta = Tensor(np.array([1, 2]), ms.float32)
+    output = F.random_gamma(shape=shape, alpha=alpha, beta=beta, seed=1)
+    expect = np.array([4, 2, 2, 2])
+
+    print(output)
+    assert (output.shape == expect).all()