[MS][LITE] StandardNormal ops dynamic support and vmap
This commit is contained in:
parent
2a953992f5
commit
4e0ea6566f
|
@ -318,6 +318,7 @@ Tensor创建
|
||||||
mindspore.ops.random_categorical
|
mindspore.ops.random_categorical
|
||||||
mindspore.ops.standard_laplace
|
mindspore.ops.standard_laplace
|
||||||
mindspore.ops.uniform
|
mindspore.ops.uniform
|
||||||
|
mindspore.ops.standard_normal
|
||||||
|
|
||||||
Array操作
|
Array操作
|
||||||
^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
|
@ -5,26 +5,4 @@ mindspore.ops.StandardNormal
|
||||||
|
|
||||||
根据标准正态(高斯)随机数分布生成随机数。
|
根据标准正态(高斯)随机数分布生成随机数。
|
||||||
|
|
||||||
返回具有给定shape的Tensor,其中的随机数从平均值为0、标准差为1的标准正态分布中取样。
|
更多参考详见 :func:`mindspore.ops.standard_normal`。
|
||||||
|
|
||||||
.. math::
|
|
||||||
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
|
|
||||||
|
|
||||||
**参数:**
|
|
||||||
|
|
||||||
- **seed** (int) - 随机种子,非负值。默认值:0。
|
|
||||||
- **seed2** (int) - 随机种子2,用来防止随机种子冲突,非负值。默认值:0。
|
|
||||||
|
|
||||||
**输入:**
|
|
||||||
|
|
||||||
- **shape** (tuple) - 目标随机数Tensor的shape。只允许常量值。
|
|
||||||
|
|
||||||
**输出:**
|
|
||||||
|
|
||||||
Tensor。shape为输入 `shape` 。数据类型支持float32。
|
|
||||||
|
|
||||||
**异常:**
|
|
||||||
|
|
||||||
- **TypeError** - `seed` 或 `seed2` 不是int类型。
|
|
||||||
- **TypeError** - `shape` 不是Tuple。
|
|
||||||
- **ValueError** - `shape` 不是常量值。
|
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
mindspore.ops.standard_normal
|
||||||
|
============================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.standard_normal(seed=0, seed2=0)
|
||||||
|
|
||||||
|
根据标准正态(高斯)随机数分布生成随机数。
|
||||||
|
|
||||||
|
返回具有给定shape的Tensor,其中的随机数从平均值为0、标准差为1的标准正态分布中取样。
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
|
||||||
|
|
||||||
|
**参数:**
|
||||||
|
|
||||||
|
- **shape** (tuple) - 目标随机数Tensor的shape。只允许常量值。
|
||||||
|
- **seed** (int) - 随机种子,非负值。默认值:0。
|
||||||
|
- **seed2** (int) - 随机种子2,用来防止随机种子冲突,非负值。默认值:0。
|
||||||
|
|
||||||
|
**返回:**
|
||||||
|
|
||||||
|
Tensor。shape为输入 `shape` 。数据类型支持float32。
|
||||||
|
|
||||||
|
**异常:**
|
||||||
|
|
||||||
|
- **TypeError** - `seed` 或 `seed2` 不是int类型。
|
||||||
|
- **TypeError** - `shape` 不是Tuple。
|
||||||
|
- **ValueError** - `shape` 不是常量值。
|
|
@ -317,6 +317,7 @@ Randomly Generating Operators
|
||||||
mindspore.ops.random_categorical
|
mindspore.ops.random_categorical
|
||||||
mindspore.ops.standard_laplace
|
mindspore.ops.standard_laplace
|
||||||
mindspore.ops.uniform
|
mindspore.ops.uniform
|
||||||
|
mindspore.ops.standard_normal
|
||||||
|
|
||||||
Array Operation
|
Array Operation
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
|
@ -359,6 +359,7 @@ constexpr const char kNameKLDiv[] = "KLDivLoss";
|
||||||
constexpr const char kNameStringLength[] = "StringLength";
|
constexpr const char kNameStringLength[] = "StringLength";
|
||||||
constexpr const char kNameGetShape[] = "GetShape";
|
constexpr const char kNameGetShape[] = "GetShape";
|
||||||
constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad";
|
constexpr const char kNameKlDivLossGrad[] = "KLDivLossGrad";
|
||||||
|
constexpr const char kNameRandomStandardNormal[] = "RandomStandardNormal";
|
||||||
|
|
||||||
class OpAdapterDesc;
|
class OpAdapterDesc;
|
||||||
|
|
||||||
|
|
|
@ -44,4 +44,11 @@ ATTR_MAP(TruncatedNormal) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
|
||||||
{"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
|
{"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
|
||||||
OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(TruncatedNormal, kNameTruncatedNormal, ADPT_DESC(TruncatedNormal))
|
REG_ADPT_DESC(TruncatedNormal, kNameTruncatedNormal, ADPT_DESC(TruncatedNormal))
|
||||||
|
|
||||||
|
// RandomStandardNormal
|
||||||
|
INPUT_MAP(RandomStandardNormal) = {{1, INPUT_DESC(shape)}};
|
||||||
|
ATTR_MAP(RandomStandardNormal) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
|
||||||
|
{"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
|
||||||
|
OUTPUT_MAP(RandomStandardNormal) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
REG_ADPT_DESC(RandomStandardNormal, kNameRandomStandardNormal, ADPT_DESC(RandomStandardNormal))
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
|
|
|
@ -34,5 +34,8 @@ DECLARE_OP_USE_OUTPUT(RandomChoiceWithMask)
|
||||||
|
|
||||||
DECLARE_OP_ADAPTER(TruncatedNormal)
|
DECLARE_OP_ADAPTER(TruncatedNormal)
|
||||||
DECLARE_OP_USE_OUTPUT(TruncatedNormal)
|
DECLARE_OP_USE_OUTPUT(TruncatedNormal)
|
||||||
|
|
||||||
|
DECLARE_OP_ADAPTER(RandomStandardNormal)
|
||||||
|
DECLARE_OP_USE_OUTPUT(RandomStandardNormal)
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RANDOM_OPS_DECLARE_H_
|
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RANDOM_OPS_DECLARE_H_
|
||||||
|
|
|
@ -108,6 +108,7 @@ PrimShapeDependMap &GetHostDependsMap() {
|
||||||
static const auto &kResizeNearestNeighborGrad = prim::kPrimResizeNearestNeighborGrad->name();
|
static const auto &kResizeNearestNeighborGrad = prim::kPrimResizeNearestNeighborGrad->name();
|
||||||
static const auto &kSegmentMean = prim::kPrimSegmentMean->name();
|
static const auto &kSegmentMean = prim::kPrimSegmentMean->name();
|
||||||
static const auto &kSegmentProd = prim::kPrimSegmentProd->name();
|
static const auto &kSegmentProd = prim::kPrimSegmentProd->name();
|
||||||
|
static const auto &kStandardNormal = prim::kPrimStandardNormal->name();
|
||||||
// Common host depends.
|
// Common host depends.
|
||||||
static PrimShapeDependMap host_depends{{kExtractGlimpse, ShapeSet{1}},
|
static PrimShapeDependMap host_depends{{kExtractGlimpse, ShapeSet{1}},
|
||||||
{kSegmentMax, ShapeSet{1}},
|
{kSegmentMax, ShapeSet{1}},
|
||||||
|
@ -164,7 +165,8 @@ PrimShapeDependMap &GetHostDependsMap() {
|
||||||
{kExpand, ShapeSet{1}},
|
{kExpand, ShapeSet{1}},
|
||||||
{kSspaddmm, ShapeSet{0, 2, 3, 5, 7}},
|
{kSspaddmm, ShapeSet{0, 2, 3, 5, 7}},
|
||||||
{kBartlettWindow, ShapeSet{0}},
|
{kBartlettWindow, ShapeSet{0}},
|
||||||
{kResizeNearestNeighborGrad, ShapeSet{1}}};
|
{kResizeNearestNeighborGrad, ShapeSet{1}},
|
||||||
|
{kStandardNormal, ShapeSet{0}}};
|
||||||
return host_depends;
|
return host_depends;
|
||||||
}
|
}
|
||||||
std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_num) {
|
std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_num) {
|
||||||
|
|
|
@ -229,6 +229,9 @@ constexpr auto kHSwishGrad = "HSwishGrad";
|
||||||
constexpr auto kSparseApplyAdagradDA = "SparseApplyAdagradDA";
|
constexpr auto kSparseApplyAdagradDA = "SparseApplyAdagradDA";
|
||||||
constexpr auto kMaxPool3DWithArgmax = "MaxPool3DWithArgmax";
|
constexpr auto kMaxPool3DWithArgmax = "MaxPool3DWithArgmax";
|
||||||
|
|
||||||
|
// Random
|
||||||
|
constexpr auto kStandardNormal = "StandardNormal";
|
||||||
|
|
||||||
// CSRTensor
|
// CSRTensor
|
||||||
constexpr auto kMakeCSRTensor = "MakeCSRTensor";
|
constexpr auto kMakeCSRTensor = "MakeCSRTensor";
|
||||||
constexpr auto kCSRTensorGetValues = "CSRTensorGetValues";
|
constexpr auto kCSRTensorGetValues = "CSRTensorGetValues";
|
||||||
|
@ -1200,7 +1203,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
|
||||||
|
|
||||||
// Random
|
// Random
|
||||||
GVAR_DEF(PrimitivePtr, kPrimStandardLaplace, std::make_shared<Primitive>("StandardLaplace"));
|
GVAR_DEF(PrimitivePtr, kPrimStandardLaplace, std::make_shared<Primitive>("StandardLaplace"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal"));
|
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>(kStandardNormal));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
|
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
|
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal"));
|
GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal"));
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "ops/random_standard_normal.h"
|
#include "ops/random_standard_normal.h"
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
@ -22,7 +23,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
MIND_API_OPERATOR_IMPL(RandomStandardNormal, BaseOperator);
|
|
||||||
void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
|
void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
|
||||||
this->set_seed(seed);
|
this->set_seed(seed);
|
||||||
this->set_seed2(seed2);
|
this->set_seed2(seed2);
|
||||||
|
@ -41,6 +41,57 @@ int64_t RandomStandardNormal::get_seed2() const {
|
||||||
auto value_ptr = GetAttr(kSeed2);
|
auto value_ptr = GetAttr(kSeed2);
|
||||||
return GetValue<int64_t>(value_ptr);
|
return GetValue<int64_t>(value_ptr);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameRandomStandardNormal, RandomStandardNormal);
|
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr RandomStandardNormalInferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||||
|
ShapeVector shape;
|
||||||
|
abstract::ShapePtr output_shape;
|
||||||
|
auto shape_value = input_args[kInputIndex0]->BuildValue();
|
||||||
|
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
|
||||||
|
shape = shape_value->isa<ValueTuple>()
|
||||||
|
? CheckAndConvertUtils::CheckTupleInt("input[shape]", shape_value, prim_name)
|
||||||
|
: CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name);
|
||||||
|
output_shape = std::make_shared<abstract::Shape>(shape);
|
||||||
|
} else {
|
||||||
|
constexpr int dynamic_rank_value = -2;
|
||||||
|
shape = {dynamic_rank_value}; // unknown dimension.
|
||||||
|
ShapeVector min_shape = {0};
|
||||||
|
ShapeVector max_shape = {abstract::Shape::SHP_ANY};
|
||||||
|
output_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
|
||||||
|
}
|
||||||
|
return output_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr RandomStandardNormalInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
return std::make_shared<TensorType>(kFloat32);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
MIND_API_OPERATOR_IMPL(RandomStandardNormal, BaseOperator);
|
||||||
|
|
||||||
|
AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
const int64_t kMinInputNum = 1;
|
||||||
|
const int64_t kMaxInputNum = 3;
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, kMinInputNum,
|
||||||
|
prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kLessEqual, kMaxInputNum,
|
||||||
|
prim_name);
|
||||||
|
auto type = RandomStandardNormalInferType(primitive, input_args);
|
||||||
|
auto shape = RandomStandardNormalInferShape(primitive, input_args);
|
||||||
|
return abstract::MakeAbstract(shape, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_HOST_DEPENDS(kNameRandomStandardNormal, {0});
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(RandomStandardNormal, prim::kPrimStandardNormal, RandomStandardNormalInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -32,7 +32,7 @@ class MIND_API RandomStandardNormal : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
MIND_API_BASE_MEMBER(RandomStandardNormal);
|
MIND_API_BASE_MEMBER(RandomStandardNormal);
|
||||||
/// \brief Constructor.
|
/// \brief Constructor.
|
||||||
RandomStandardNormal() : BaseOperator(kNameRandomStandardNormal) {}
|
RandomStandardNormal() : BaseOperator(kNameRandomStandardNormal) { InitIOName({"shape"}, {"output"}); }
|
||||||
|
|
||||||
/// \brief Method to init the op's attributes.
|
/// \brief Method to init the op's attributes.
|
||||||
///
|
///
|
||||||
|
@ -60,6 +60,9 @@ class MIND_API RandomStandardNormal : public BaseOperator {
|
||||||
/// \return seed2 attributes.
|
/// \return seed2 attributes.
|
||||||
int64_t get_seed2() const;
|
int64_t get_seed2() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -98,3 +98,4 @@ get_unsupported_dynamic_vmap_rule = \
|
||||||
vmap_rules_getters.register(P.StandardLaplace)(get_unsupported_dynamic_vmap_rule)
|
vmap_rules_getters.register(P.StandardLaplace)(get_unsupported_dynamic_vmap_rule)
|
||||||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformInt)(get_unsupported_dynamic_vmap_rule)
|
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformInt)(get_unsupported_dynamic_vmap_rule)
|
||||||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformReal)(get_unsupported_dynamic_vmap_rule)
|
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.UniformReal)(get_unsupported_dynamic_vmap_rule)
|
||||||
|
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.StandardNormal)(get_unsupported_dynamic_vmap_rule)
|
||||||
|
|
|
@ -283,6 +283,7 @@ from .random_func import (
|
||||||
standard_laplace,
|
standard_laplace,
|
||||||
random_categorical,
|
random_categorical,
|
||||||
uniform,
|
uniform,
|
||||||
|
standard_normal,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
|
|
|
@ -176,9 +176,48 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def standard_normal(shape, seed=0, seed2=0):
|
||||||
|
r"""
|
||||||
|
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
|
||||||
|
|
||||||
|
Returns the tensor with the given shape, the random numbers in it drawn from normal distributions
|
||||||
|
whose mean is 0 and standard deviation is 1.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape (tuple): The shape of random tensor to be generated. Only constant value is allowed.
|
||||||
|
seed (int): Random seed, must be non-negative. Default: 0.
|
||||||
|
seed2 (int): Random seed2, must be non-negative. A second seed to avoid seed collision. Default: 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor. The shape is the same as the input `shape`. The dtype is float32.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If neither `seed` nor `seed2` is an int.
|
||||||
|
TypeError: If `shape` is not a tuple.
|
||||||
|
ValueError: If `shape` is not a constant value.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mindspore.ops import functional as F
|
||||||
|
>>> shape = (4, 4)
|
||||||
|
>>> output = F.standard_normal(shape)
|
||||||
|
>>> result = output.shape
|
||||||
|
>>> print(result)
|
||||||
|
(4, 4)
|
||||||
|
"""
|
||||||
|
standard_normal_op = _get_cache_prim(P.StandardNormal)(seed=seed, seed2=seed2)
|
||||||
|
return standard_normal_op(shape)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'standard_laplace',
|
'standard_laplace',
|
||||||
'random_categorical',
|
'random_categorical',
|
||||||
'uniform',
|
'uniform',
|
||||||
|
'standard_normal',
|
||||||
]
|
]
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -1025,6 +1025,7 @@ tensor_operator_registry.register('standard_laplace', P.StandardLaplace)
|
||||||
tensor_operator_registry.register('split', P.Split)
|
tensor_operator_registry.register('split', P.Split)
|
||||||
tensor_operator_registry.register('erf', P.Erf)
|
tensor_operator_registry.register('erf', P.Erf)
|
||||||
tensor_operator_registry.register('erfc', P.Erfc)
|
tensor_operator_registry.register('erfc', P.Erfc)
|
||||||
|
tensor_operator_registry.register('standard_normal', P.StandardNormal)
|
||||||
# ms cannot support Tensor(True) compare
|
# ms cannot support Tensor(True) compare
|
||||||
tensor_operator_registry.register('__eq__', equal)
|
tensor_operator_registry.register('__eq__', equal)
|
||||||
tensor_operator_registry.register('__ne__', not_equal)
|
tensor_operator_registry.register('__ne__', not_equal)
|
||||||
|
|
|
@ -133,26 +133,7 @@ class StandardNormal(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
|
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
|
||||||
|
|
||||||
Returns the tensor with the given shape, the random numbers in it drawn from normal distributions
|
Refer to :func:`mindspore.ops.standard_normal` for more detail.
|
||||||
whose mean is 0 and standard deviation is 1.
|
|
||||||
|
|
||||||
.. math::
|
|
||||||
f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed (int): Random seed, must be non-negative. Default: 0.
|
|
||||||
seed2 (int): Random seed2, must be non-negative. A second seed to avoid seed collision. Default: 0.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor. The shape is the same as the input `shape`. The dtype is float32.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If neither `seed` nor `seed2` is an int.
|
|
||||||
TypeError: If `shape` is not a tuple.
|
|
||||||
ValueError: If `shape` is not a constant value.
|
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU`` ``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
|
@ -14,9 +14,11 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
|
||||||
|
@ -61,3 +63,27 @@ def test_net():
|
||||||
net = Net(shape, seed, seed2)
|
net = Net(shape, seed, seed2)
|
||||||
output = net()
|
output = net()
|
||||||
assert output.shape == (130, 120, 141)
|
assert output.shape == (130, 120, 141)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_standard_normal_functional():
|
||||||
|
"""
|
||||||
|
Feature: Functional interface of StandardNormal CPU operation
|
||||||
|
Description: input the shape and random seed, test the output value and shape
|
||||||
|
Expectation: the value and shape of output tensor match the predefined values
|
||||||
|
"""
|
||||||
|
seed = 10
|
||||||
|
seed2 = 10
|
||||||
|
shape = (5, 6, 8)
|
||||||
|
output = F.standard_normal(shape, seed, seed2)
|
||||||
|
assert output.shape == shape
|
||||||
|
output_numpy_flatten_1 = output.asnumpy().flatten()
|
||||||
|
|
||||||
|
seed = 0
|
||||||
|
seed2 = 10
|
||||||
|
output = F.standard_normal(shape, seed, seed2)
|
||||||
|
assert output.shape == shape
|
||||||
|
output_numpy_flatten_2 = output.asnumpy().flatten()
|
||||||
|
assert (output_numpy_flatten_1 == output_numpy_flatten_2).all()
|
||||||
|
|
Loading…
Reference in New Issue