!49463 support randperm function
Merge pull request !49463 from 吕昱峰(Nate.River)/randperm
This commit is contained in:
commit
be9ceffb0a
|
@ -470,6 +470,7 @@ Tensor创建
|
||||||
mindspore.ops.randn
|
mindspore.ops.randn
|
||||||
mindspore.ops.randn_like
|
mindspore.ops.randn_like
|
||||||
mindspore.ops.random_poisson
|
mindspore.ops.random_poisson
|
||||||
|
mindspore.ops.randperm
|
||||||
mindspore.ops.standard_laplace
|
mindspore.ops.standard_laplace
|
||||||
mindspore.ops.standard_normal
|
mindspore.ops.standard_normal
|
||||||
mindspore.ops.uniform
|
mindspore.ops.uniform
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
mindspore.ops.randperm
|
||||||
|
========================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.randperm(n, seed=0, offset=0, dtype=mstype.int64)
|
||||||
|
|
||||||
|
生成从 0 到 n-1 的整数随机排列。
|
||||||
|
|
||||||
|
返回由 n 推断出的具有确定形状的张量,其中的随机数取自给定类型可以表示的数据范围。
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
`n` 必须大于0。
|
||||||
|
|
||||||
|
返回一个Tensor,shape和dtype由输入决定,其元素为服从标准正态分布的 :math:`[0, 1)` 区间的数字。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **n** (Union[Tensor, int]) - 输入大小,如果为Tensor,则形状为()或(1,),数据类型为int64。
|
||||||
|
- **seed** (int,可选): 随机种子。 默认值:0。当seed为-1(只有负值)时,offset为0,由时间决定。
|
||||||
|
- **offset** (int,可选): 优先级高于随机种子。 默认值:0。必须是非负数。
|
||||||
|
- **dtype** (:class:`mindspore.dtype`,可选):输出的类型。必须是以下类型之一:int32、int16、int8、uint8、int64、float64、float32、float16。 默认值:int64。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Tensor,shape由参数 `n` 决定,dtype由参数 `dtype` 决定。
|
||||||
|
|
||||||
|
异常:
|
||||||
|
- **TypeError** - 如果 `dtype` 不是一个 `mstype.float_type` 类型。
|
||||||
|
- **ValueError** - 如果 `n` 是负数或0。
|
||||||
|
- **ValueError** - 如果 `seed` 不是非负整数。
|
||||||
|
- **ValueError** - 如果 `n` 是超过指定数据类型的最大范围。
|
|
@ -470,6 +470,7 @@ Randomly Generating Functions
|
||||||
mindspore.ops.randn
|
mindspore.ops.randn
|
||||||
mindspore.ops.randn_like
|
mindspore.ops.randn_like
|
||||||
mindspore.ops.random_poisson
|
mindspore.ops.random_poisson
|
||||||
|
mindspore.ops.randperm
|
||||||
mindspore.ops.standard_laplace
|
mindspore.ops.standard_laplace
|
||||||
mindspore.ops.standard_normal
|
mindspore.ops.standard_normal
|
||||||
mindspore.ops.uniform
|
mindspore.ops.uniform
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops.primitive import constexpr, _primexpr
|
from mindspore.ops.primitive import constexpr, _primexpr
|
||||||
|
@ -584,6 +585,11 @@ def choice_with_mask(input_x, count=256, seed=None):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def is_cpu_backend():
|
||||||
|
return context.get_context('device_target') == 'CPU'
|
||||||
|
|
||||||
|
|
||||||
@_function_forbid_reuse
|
@_function_forbid_reuse
|
||||||
def randperm(n, seed=0, offset=0, dtype=mstype.int64):
|
def randperm(n, seed=0, offset=0, dtype=mstype.int64):
|
||||||
r"""
|
r"""
|
||||||
|
@ -612,18 +618,26 @@ def randperm(n, seed=0, offset=0, dtype=mstype.int64):
|
||||||
ValueError: If `n` is larger than the maximal data of the set dtype.
|
ValueError: If `n` is larger than the maximal data of the set dtype.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> n = Tensor([4], mstype.int64)
|
>>> n = 4
|
||||||
>>> seed = 0
|
>>> seed = 0
|
||||||
>>> offset = 0
|
>>> offset = 0
|
||||||
>>> output = ops.randperm(n, seed, offset, dtype=mstype.int64)
|
>>> output = ops.randperm(n, seed, offset, dtype=mstype.int64)
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
[1 0 2 3]
|
[1 0 2 3]
|
||||||
"""
|
"""
|
||||||
|
if is_cpu_backend():
|
||||||
|
if isinstance(n, int):
|
||||||
|
n = Tensor(n)
|
||||||
randperm_ = _get_cache_prim(RandpermV2)(dtype=dtype)
|
randperm_ = _get_cache_prim(RandpermV2)(dtype=dtype)
|
||||||
return randperm_(n, seed, offset)
|
return randperm_(n, seed, offset)
|
||||||
|
if isinstance(n, Tensor):
|
||||||
|
n = int(n)
|
||||||
|
randperm_ = _get_cache_prim(P.Randperm)(max_length=n, dtype=dtype)
|
||||||
|
return randperm_(Tensor((n,)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@_function_forbid_reuse
|
@_function_forbid_reuse
|
||||||
|
|
Loading…
Reference in New Issue