[MS][LITE] StandardNormal ops dynamic support and vmap

This commit is contained in:
luoyuan 2022-06-27 21:50:28 +08:00
parent 2a953992f5
commit 4e0ea6566f
17 changed files with 174 additions and 48 deletions

View File

@ -318,6 +318,7 @@ Tensor创建
mindspore.ops.random_categorical mindspore.ops.random_categorical
mindspore.ops.standard_laplace mindspore.ops.standard_laplace
mindspore.ops.uniform mindspore.ops.uniform
mindspore.ops.standard_normal
Array操作 Array操作
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@ -5,26 +5,4 @@ mindspore.ops.StandardNormal
根据标准正态(高斯)随机数分布生成随机数。 根据标准正态(高斯)随机数分布生成随机数。
返回具有给定shape的Tensor其中的随机数从平均值为0、标准差为1的标准正态分布中取样。 更多参考详见 :func:`mindspore.ops.standard_normal`
.. math::
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
**参数:**
- **seed** (int) - 随机种子非负值。默认值0。
- **seed2** (int) - 随机种子2用来防止随机种子冲突非负值。默认值0。
**输入:**
- **shape** (tuple) - 目标随机数Tensor的shape。只允许常量值。
**输出:**
Tensor。shape为输入 `shape` 。数据类型支持float32。
**异常:**
- **TypeError** - `seed``seed2` 不是int类型。
- **TypeError** - `shape` 不是Tuple。
- **ValueError** - `shape` 不是常量值。

View File

@ -0,0 +1,27 @@
mindspore.ops.standard_normal
============================
.. py:function:: mindspore.ops.standard_normal(seed=0, seed2=0)
根据标准正态(高斯)随机数分布生成随机数。
返回具有给定shape的Tensor其中的随机数从平均值为0、标准差为1的标准正态分布中取样。
.. math::
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
**参数:**
- **shape** (tuple) - 目标随机数Tensor的shape。只允许常量值。
- **seed** (int) - 随机种子非负值。默认值0。
- **seed2** (int) - 随机种子2用来防止随机种子冲突非负值。默认值0。
**返回:**
Tensor。shape为输入 `shape` 。数据类型支持float32。
**异常:**
- **TypeError** - `seed``seed2` 不是int类型。
- **TypeError** - `shape` 不是Tuple。
- **ValueError** - `shape` 不是常量值。

View File

@ -317,6 +317,7 @@ Randomly Generating Operators
mindspore.ops.random_categorical mindspore.ops.random_categorical
mindspore.ops.standard_laplace mindspore.ops.standard_laplace
mindspore.ops.uniform mindspore.ops.uniform
mindspore.ops.standard_normal
Array Operation Array Operation
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@ -359,6 +359,7 @@ constexpr const char kNameKLDiv[] = "KLDivLoss";
constexpr const char kNameStringLength[] = "StringLength"; constexpr const char kNameStringLength[] = "StringLength";
constexpr const char kNameGetShape[] = "GetShape"; constexpr const char kNameGetShape[] = "GetShape";
constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad"; constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad";
constexpr const char kNameRandomStandardNormal[] = "RandomStandardNormal";
class OpAdapterDesc; class OpAdapterDesc;

View File

@ -44,4 +44,11 @@ ATTR_MAP(TruncatedNormal) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
{"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}}; {"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(TruncatedNormal, kNameTruncatedNormal, ADPT_DESC(TruncatedNormal)) REG_ADPT_DESC(TruncatedNormal, kNameTruncatedNormal, ADPT_DESC(TruncatedNormal))
// RandomStandardNormal
INPUT_MAP(RandomStandardNormal) = {{1, INPUT_DESC(shape)}};
ATTR_MAP(RandomStandardNormal) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
{"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
OUTPUT_MAP(RandomStandardNormal) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(RandomStandardNormal, kNameRandomStandardNormal, ADPT_DESC(RandomStandardNormal))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -34,5 +34,8 @@ DECLARE_OP_USE_OUTPUT(RandomChoiceWithMask)
DECLARE_OP_ADAPTER(TruncatedNormal) DECLARE_OP_ADAPTER(TruncatedNormal)
DECLARE_OP_USE_OUTPUT(TruncatedNormal) DECLARE_OP_USE_OUTPUT(TruncatedNormal)
DECLARE_OP_ADAPTER(RandomStandardNormal)
DECLARE_OP_USE_OUTPUT(RandomStandardNormal)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RANDOM_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RANDOM_OPS_DECLARE_H_

View File

@ -108,6 +108,7 @@ PrimShapeDependMap &GetHostDependsMap() {
static const auto &kResizeNearestNeighborGrad = prim::kPrimResizeNearestNeighborGrad->name(); static const auto &kResizeNearestNeighborGrad = prim::kPrimResizeNearestNeighborGrad->name();
static const auto &kSegmentMean = prim::kPrimSegmentMean->name(); static const auto &kSegmentMean = prim::kPrimSegmentMean->name();
static const auto &kSegmentProd = prim::kPrimSegmentProd->name(); static const auto &kSegmentProd = prim::kPrimSegmentProd->name();
static const auto &kStandardNormal = prim::kPrimStandardNormal->name();
// Common host depends. // Common host depends.
static PrimShapeDependMap host_depends{{kExtractGlimpse, ShapeSet{1}}, static PrimShapeDependMap host_depends{{kExtractGlimpse, ShapeSet{1}},
{kSegmentMax, ShapeSet{1}}, {kSegmentMax, ShapeSet{1}},
@ -164,7 +165,8 @@ PrimShapeDependMap &GetHostDependsMap() {
{kExpand, ShapeSet{1}}, {kExpand, ShapeSet{1}},
{kSspaddmm, ShapeSet{0, 2, 3, 5, 7}}, {kSspaddmm, ShapeSet{0, 2, 3, 5, 7}},
{kBartlettWindow, ShapeSet{0}}, {kBartlettWindow, ShapeSet{0}},
{kResizeNearestNeighborGrad, ShapeSet{1}}}; {kResizeNearestNeighborGrad, ShapeSet{1}},
{kStandardNormal, ShapeSet{0}}};
return host_depends; return host_depends;
} }
std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_num) { std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_num) {

View File

@ -229,6 +229,9 @@ constexpr auto kHSwishGrad = "HSwishGrad";
constexpr auto kSparseApplyAdagradDA = "SparseApplyAdagradDA"; constexpr auto kSparseApplyAdagradDA = "SparseApplyAdagradDA";
constexpr auto kMaxPool3DWithArgmax = "MaxPool3DWithArgmax"; constexpr auto kMaxPool3DWithArgmax = "MaxPool3DWithArgmax";
// Random
constexpr auto kStandardNormal = "StandardNormal";
// CSRTensor // CSRTensor
constexpr auto kMakeCSRTensor = "MakeCSRTensor"; constexpr auto kMakeCSRTensor = "MakeCSRTensor";
constexpr auto kCSRTensorGetValues = "CSRTensorGetValues"; constexpr auto kCSRTensorGetValues = "CSRTensorGetValues";
@ -1200,7 +1203,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
// Random // Random
GVAR_DEF(PrimitivePtr, kPrimStandardLaplace, std::make_shared<Primitive>("StandardLaplace")); GVAR_DEF(PrimitivePtr, kPrimStandardLaplace, std::make_shared<Primitive>("StandardLaplace"));
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal")); GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>(kStandardNormal));
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal")); GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts")); GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal")); GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal"));

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "ops/random_standard_normal.h" #include "ops/random_standard_normal.h"
#include <set>
#include <string> #include <string>
#include <memory> #include <memory>
#include "ops/op_utils.h" #include "ops/op_utils.h"
@ -22,7 +23,6 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
MIND_API_OPERATOR_IMPL(RandomStandardNormal, BaseOperator);
void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) { void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
this->set_seed(seed); this->set_seed(seed);
this->set_seed2(seed2); this->set_seed2(seed2);
@ -41,6 +41,57 @@ int64_t RandomStandardNormal::get_seed2() const {
auto value_ptr = GetAttr(kSeed2); auto value_ptr = GetAttr(kSeed2);
return GetValue<int64_t>(value_ptr); return GetValue<int64_t>(value_ptr);
} }
REGISTER_PRIMITIVE_C(kNameRandomStandardNormal, RandomStandardNormal);
namespace {
abstract::ShapePtr RandomStandardNormalInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
ShapeVector shape;
abstract::ShapePtr output_shape;
auto shape_value = input_args[kInputIndex0]->BuildValue();
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
shape = shape_value->isa<ValueTuple>()
? CheckAndConvertUtils::CheckTupleInt("input[shape]", shape_value, prim_name)
: CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name);
output_shape = std::make_shared<abstract::Shape>(shape);
} else {
constexpr int dynamic_rank_value = -2;
shape = {dynamic_rank_value}; // unknown dimension.
ShapeVector min_shape = {0};
ShapeVector max_shape = {abstract::Shape::SHP_ANY};
output_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
}
return output_shape;
}
TypePtr RandomStandardNormalInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return std::make_shared<TensorType>(kFloat32);
}
} // namespace
MIND_API_OPERATOR_IMPL(RandomStandardNormal, BaseOperator);
AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t kMinInputNum = 1;
const int64_t kMaxInputNum = 3;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, kMinInputNum,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kLessEqual, kMaxInputNum,
prim_name);
auto type = RandomStandardNormalInferType(primitive, input_args);
auto shape = RandomStandardNormalInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_HOST_DEPENDS(kNameRandomStandardNormal, {0});
REGISTER_PRIMITIVE_EVAL_IMPL(RandomStandardNormal, prim::kPrimStandardNormal, RandomStandardNormalInfer, nullptr, true);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -32,7 +32,7 @@ class MIND_API RandomStandardNormal : public BaseOperator {
public: public:
MIND_API_BASE_MEMBER(RandomStandardNormal); MIND_API_BASE_MEMBER(RandomStandardNormal);
/// \brief Constructor. /// \brief Constructor.
RandomStandardNormal() : BaseOperator(kNameRandomStandardNormal) {} RandomStandardNormal() : BaseOperator(kNameRandomStandardNormal) { InitIOName({"shape"}, {"output"}); }
/// \brief Method to init the op's attributes. /// \brief Method to init the op's attributes.
/// ///
@ -60,6 +60,9 @@ class MIND_API RandomStandardNormal : public BaseOperator {
/// \return seed2 attributes. /// \return seed2 attributes.
int64_t get_seed2() const; int64_t get_seed2() const;
}; };
abstract::AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -98,3 +98,4 @@ get_unsupported_dynamic_vmap_rule = \
vmap_rules_getters.register(P.StandardLaplace)(get_unsupported_dynamic_vmap_rule) vmap_rules_getters.register(P.StandardLaplace)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformInt)(get_unsupported_dynamic_vmap_rule) get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformInt)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformReal)(get_unsupported_dynamic_vmap_rule) get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformReal)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.StandardNormal)(get_unsupported_dynamic_vmap_rule)

View File

@ -283,6 +283,7 @@ from .random_func import (
standard_laplace, standard_laplace,
random_categorical, random_categorical,
uniform, uniform,
standard_normal,
) )
__all__ = [] __all__ = []

View File

@ -176,9 +176,48 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
return value return value
def standard_normal(shape, seed=0, seed2=0):
r"""
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
Returns the tensor with the given shape, the random numbers in it drawn from normal distributions
whose mean is 0 and standard deviation is 1.
.. math::
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
Args:
shape (tuple): The shape of random tensor to be generated. Only constant value is allowed.
seed (int): Random seed, must be non-negative. Default: 0.
seed2 (int): Random seed2, must be non-negative. A second seed to avoid seed collision. Default: 0.
Returns:
Tensor. The shape is the same as the input `shape`. The dtype is float32.
Raises:
TypeError: If neither `seed` nor `seed2` is an int.
TypeError: If `shape` is not a tuple.
ValueError: If `shape` is not a constant value.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> shape = (4, 4)
>>> output = F.standard_normal(shape)
>>> result = output.shape
>>> print(result)
(4, 4)
"""
standard_normal_op = _get_cache_prim(P.StandardNormal)(seed=seed, seed2=seed2)
return standard_normal_op(shape)
__all__ = [ __all__ = [
'standard_laplace', 'standard_laplace',
'random_categorical', 'random_categorical',
'uniform', 'uniform',
'standard_normal',
] ]
__all__.sort() __all__.sort()

View File

@ -1025,6 +1025,7 @@ tensor_operator_registry.register('standard_laplace', P.StandardLaplace)
tensor_operator_registry.register('split', P.Split) tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('erf', P.Erf) tensor_operator_registry.register('erf', P.Erf)
tensor_operator_registry.register('erfc', P.Erfc) tensor_operator_registry.register('erfc', P.Erfc)
tensor_operator_registry.register('standard_normal', P.StandardNormal)
# ms cannot support Tensor(True) compare # ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal) tensor_operator_registry.register('__ne__', not_equal)

View File

@ -133,26 +133,7 @@ class StandardNormal(PrimitiveWithInfer):
r""" r"""
Generates random numbers according to the standard Normal (or Gaussian) random number distribution. Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
Returns the tensor with the given shape, the random numbers in it drawn from normal distributions Refer to :func:`mindspore.ops.standard_normal` for more detail.
whose mean is 0 and standard deviation is 1.
.. math::
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
Args:
seed (int): Random seed, must be non-negative. Default: 0.
seed2 (int): Random seed2, must be non-negative. A second seed to avoid seed collision. Default: 0.
Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
Outputs:
Tensor. The shape is the same as the input `shape`. The dtype is float32.
Raises:
TypeError: If neither `seed` nor `seed2` is an int.
TypeError: If `shape` is not a tuple.
ValueError: If `shape` is not a constant value.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``

View File

@ -14,9 +14,11 @@
# ============================================================================ # ============================================================================
import pytest import pytest
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -61,3 +63,27 @@ def test_net():
net = Net(shape, seed, seed2) net = Net(shape, seed, seed2)
output = net() output = net()
assert output.shape == (130, 120, 141) assert output.shape == (130, 120, 141)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_standard_normal_functional():
"""
Feature: Functional interface of StandardNormal CPU operation
Description: input the shape and random seed, test the output value and shape
Expectation: the value and shape of output tensor match the predefined values
"""
seed = 10
seed2 = 10
shape = (5, 6, 8)
output = F.standard_normal(shape, seed, seed2)
assert output.shape == shape
output_numpy_flatten_1 = output.asnumpy().flatten()
seed = 0
seed2 = 10
output = F.standard_normal(shape, seed, seed2)
assert output.shape == shape
output_numpy_flatten_2 = output.asnumpy().flatten()
assert (output_numpy_flatten_1 == output_numpy_flatten_2).all()