!49463 support randperm function
Merge pull request !49463 from 吕昱峰(Nate.River)/randperm
This commit is contained in:
commit
be9ceffb0a
|
@ -470,6 +470,7 @@ Tensor创建
|
|||
mindspore.ops.randn
|
||||
mindspore.ops.randn_like
|
||||
mindspore.ops.random_poisson
|
||||
mindspore.ops.randperm
|
||||
mindspore.ops.standard_laplace
|
||||
mindspore.ops.standard_normal
|
||||
mindspore.ops.uniform
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
mindspore.ops.randperm
|
||||
========================
|
||||
|
||||
.. py:function:: mindspore.ops.randperm(n, seed=0, offset=0, dtype=mstype.int64)
|
||||
|
||||
生成从 0 到 n-1 的整数随机排列。
|
||||
|
||||
返回由 n 推断出的具有确定形状的张量,其中的随机数取自给定类型可以表示的数据范围。
|
||||
|
||||
.. note::
|
||||
`n` 必须大于0。
|
||||
|
||||
返回一个Tensor,shape和dtype由输入决定,其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
|
||||
|
||||
参数:
|
||||
- **n** (Union[Tensor, int]) - 输入大小,如果为Tensor,则形状为()或(1,),数据类型为int64。
|
||||
- **seed** (int,可选): 随机种子。 默认值:0。当seed为-1(只有负值)时,offset为0,由时间决定。
|
||||
- **offset** (int,可选): 优先级高于随机种子。 默认值:0。必须是非负数。
|
||||
- **dtype** (:class:`mindspore.dtype`,可选):输出的类型。必须是以下类型之一:int32、int16、int8、uint8、int64、float64、float32、float16。 默认值:int64。
|
||||
|
||||
返回:
|
||||
Tensor,shape由参数 `n` 决定,dtype由参数 `dtype` 决定。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。
|
||||
- **ValueError** - 如果 `n` 是负数或0。
|
||||
- **ValueError** - 如果 `seed` 不是非负整数。
|
||||
- **ValueError** - 如果 `n` 是超过指定数据类型的最大范围。
|
|
@ -470,6 +470,7 @@ Randomly Generating Functions
|
|||
mindspore.ops.randn
|
||||
mindspore.ops.randn_like
|
||||
mindspore.ops.random_poisson
|
||||
mindspore.ops.randperm
|
||||
mindspore.ops.standard_laplace
|
||||
mindspore.ops.standard_normal
|
||||
mindspore.ops.uniform
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
from __future__ import absolute_import
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.primitive import constexpr, _primexpr
|
||||
|
@ -584,6 +585,11 @@ def choice_with_mask(input_x, count=256, seed=None):
|
|||
return output
|
||||
|
||||
|
||||
@constexpr
|
||||
def is_cpu_backend():
|
||||
return context.get_context('device_target') == 'CPU'
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
def randperm(n, seed=0, offset=0, dtype=mstype.int64):
|
||||
r"""
|
||||
|
@ -612,18 +618,26 @@ def randperm(n, seed=0, offset=0, dtype=mstype.int64):
|
|||
ValueError: If `n` is larger than the maximal data of the set dtype.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> n = Tensor([4], mstype.int64)
|
||||
>>> n = 4
|
||||
>>> seed = 0
|
||||
>>> offset = 0
|
||||
>>> output = ops.randperm(n, seed, offset, dtype=mstype.int64)
|
||||
>>> print(output)
|
||||
[1 0 2 3]
|
||||
"""
|
||||
if is_cpu_backend():
|
||||
if isinstance(n, int):
|
||||
n = Tensor(n)
|
||||
randperm_ = _get_cache_prim(RandpermV2)(dtype=dtype)
|
||||
return randperm_(n, seed, offset)
|
||||
if isinstance(n, Tensor):
|
||||
n = int(n)
|
||||
randperm_ = _get_cache_prim(P.Randperm)(max_length=n, dtype=dtype)
|
||||
return randperm_(Tensor((n,)))
|
||||
|
||||
|
||||
|
||||
@_function_forbid_reuse
|
||||
|
|
Loading…
Reference in New Issue