!49463 support randperm function

Merge pull request !49463 from 吕昱峰(Nate.River)/randperm
This commit is contained in:
i-robot 2023-02-28 01:46:27 +00:00 committed by Gitee
commit be9ceffb0a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 49 additions and 5 deletions

View File

@ -470,6 +470,7 @@ Tensor创建
mindspore.ops.randn
mindspore.ops.randn_like
mindspore.ops.random_poisson
mindspore.ops.randperm
mindspore.ops.standard_laplace
mindspore.ops.standard_normal
mindspore.ops.uniform

View File

@ -0,0 +1,28 @@
mindspore.ops.randperm
========================
.. py:function:: mindspore.ops.randperm(n, seed=0, offset=0, dtype=mstype.int64)
生成从 0 到 n-1 的整数随机排列。
返回由 n 推断出的具有确定形状的张量,其中的随机数取自给定类型可以表示的数据范围。
.. note::
`n` 必须大于0。
返回一个Tensorshape和dtype由输入决定其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
参数:
- **n** (Union[Tensor, int]) - 输入大小如果为Tensor则形状为()或(1,)数据类型为int64。
- **seed** (int可选): 随机种子。 默认值0。当seed为-1只有负值offset为0由时间决定。
- **offset** (int可选): 优先级高于随机种子。 默认值0。必须是非负数。
- **dtype** (:class:`mindspore.dtype`,可选)输出的类型。必须是以下类型之一int32、int16、int8、uint8、int64、float64、float32、float16。 默认值int64。
返回:
Tensorshape由参数 `n` 决定dtype由参数 `dtype` 决定。
异常:
- **TypeError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。
- **ValueError** - 如果 `n` 是负数或0。
- **ValueError** - 如果 `seed` 不是非负整数。
- **ValueError** - 如果 `n` 是超过指定数据类型的最大范围。

View File

@ -470,6 +470,7 @@ Randomly Generating Functions
mindspore.ops.randn
mindspore.ops.randn_like
mindspore.ops.random_poisson
mindspore.ops.randperm
mindspore.ops.standard_laplace
mindspore.ops.standard_normal
mindspore.ops.uniform

View File

@ -16,6 +16,7 @@
from __future__ import absolute_import
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr, _primexpr
@ -584,6 +585,11 @@ def choice_with_mask(input_x, count=256, seed=None):
return output
@constexpr
def is_cpu_backend():
return context.get_context('device_target') == 'CPU'
@_function_forbid_reuse
def randperm(n, seed=0, offset=0, dtype=mstype.int64):
r"""
@ -612,18 +618,26 @@ def randperm(n, seed=0, offset=0, dtype=mstype.int64):
ValueError: If `n` is larger than the maximal data of the set dtype.
Supported Platforms:
``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> n = Tensor([4], mstype.int64)
>>> n = 4
>>> seed = 0
>>> offset = 0
>>> output = ops.randperm(n, seed, offset, dtype=mstype.int64)
>>> print(output)
[1 0 2 3]
"""
if is_cpu_backend():
if isinstance(n, int):
n = Tensor(n)
randperm_ = _get_cache_prim(RandpermV2)(dtype=dtype)
return randperm_(n, seed, offset)
if isinstance(n, Tensor):
n = int(n)
randperm_ = _get_cache_prim(P.Randperm)(max_length=n, dtype=dtype)
return randperm_(Tensor((n,)))
@_function_forbid_reuse