forked from mindspore-Ecosystem/mindspore
support arange ops
This commit is contained in:
parent
e0eef48f83
commit
40fadfe6f1
|
@ -354,6 +354,7 @@ Array操作
|
|||
|
||||
mindspore.ops.adaptive_max_pool2d
|
||||
mindspore.ops.affine_grid
|
||||
mindspore.ops.arange
|
||||
mindspore.ops.batch_to_space_nd
|
||||
mindspore.ops.broadcast_to
|
||||
mindspore.ops.col2im
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.ops.arange
|
||||
=====================
|
||||
|
||||
.. py:function:: mindspore.ops.arange(start=0, end=None, step=1, *, dtype=None)
|
||||
|
||||
返回从 `start` 开始,步长为 `step` ,且不超过 `end` (不包括 `end` )的序列。
|
||||
|
||||
参数:
|
||||
- **start** (Union[float, int, Tensor]) - 序列中的第一个数字。
|
||||
- **end** (Union[float, int, Tensor]) - 序列的上限或下限,不包含在序列中。
|
||||
- **step** (Union[float, int, Tensor]) - 表述序列中数值的步长。
|
||||
- **dtype** (mindspore.dtype, 可选) - 返回序列的数据类型。默认值:None。如果未指定或者为None,将会被推断为 `start` 、 `end` 和 `step` 参数中精度最高的类型。
|
||||
|
||||
返回:
|
||||
一维Tensor,数据类型与输入数据类型一致。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `start` , `end` , `step` 既不是int也不是float也不是在支持类型中的TensorScalar。
|
||||
- **ValueError** - `step` 等于0。
|
||||
- **ValueError** - `start` 小于等于 `end` , `step` 小于0。
|
||||
- **ValueError** - `start` 大于等于 `end` , `step` 大于0。
|
|
@ -354,6 +354,7 @@ Array Operation
|
|||
|
||||
mindspore.ops.adaptive_max_pool2d
|
||||
mindspore.ops.affine_grid
|
||||
mindspore.ops.arange
|
||||
mindspore.ops.batch_to_space_nd
|
||||
mindspore.ops.broadcast_to
|
||||
mindspore.ops.col2im
|
||||
|
|
|
@ -96,6 +96,7 @@ from .array_func import (
|
|||
space_to_batch_nd,
|
||||
batch_to_space_nd,
|
||||
range,
|
||||
arange,
|
||||
select,
|
||||
one_hot,
|
||||
matrix_diag,
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
|
||||
"""Operators for function."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import builtins
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -117,6 +120,107 @@ def _check_attr_dtype(param_name, input_dtype, allow_dtypes, cls_name):
|
|||
##############################
|
||||
|
||||
|
||||
def _cast_type(x, to_type):
|
||||
"""cast input to the specified type or cast input to tensor"""
|
||||
if isinstance(x, Tensor):
|
||||
x = cast_(x, to_type)
|
||||
else:
|
||||
x = scalar_to_tensor_(x, to_type)
|
||||
return x
|
||||
|
||||
|
||||
def _get_type(x):
|
||||
"""get the dtype of input"""
|
||||
if isinstance(x, Tensor):
|
||||
return x.dtype
|
||||
return type(x)
|
||||
|
||||
|
||||
def _get_max_type(start, end, step):
|
||||
"""get max input type with `level`"""
|
||||
valid_dtypes = [mstype.int32, mstype.float32, mstype.int64, mstype.float64]
|
||||
arg_map = [start, end, step]
|
||||
arg_type_map = [str(_get_type(i)) for i in arg_map]
|
||||
for arg_value in arg_map:
|
||||
if not (isinstance(arg_value, (float, int))
|
||||
or (isinstance(arg_value, Tensor) and arg_value.dtype in valid_dtypes)):
|
||||
raise TypeError(
|
||||
f"For arange, the input type must be int or float or a TensorScalar in {valid_dtypes},"
|
||||
f" but got {_get_type(arg_value)}")
|
||||
|
||||
type_map = {'Float64': '3', 'Float32': '2', "<class 'float'>": '2', 'Int64': '1', "<class 'int'>": '1',
|
||||
'Int32': '0'}
|
||||
type_map_reverse = {'3': mstype.float64, '2': mstype.float32, '1': mstype.int64, '0': mstype.int32}
|
||||
type_level = [type_map.get(i) for i in arg_type_map]
|
||||
max_level = builtins.max(type_level)
|
||||
return type_map_reverse.get(max_level)
|
||||
|
||||
|
||||
def arange(start=0, end=None, step=1, *, dtype=None):
|
||||
r"""
|
||||
Creates a sequence of numbers that begins at `start` and extends by increments of
|
||||
`step` up to but not including `end`.
|
||||
|
||||
Args:
|
||||
start (Union[float, int, Tensor]): The first number in the sequence.
|
||||
end (Union[float, int, Tensor]): Upper or lower limit of the sequence, exclusive.
|
||||
step (Union[float, int, Tensor]): Number that increments `start`.
|
||||
dtype (mindspore.dtype, optional): The desired data type of returned tensor. Default: None.
|
||||
If dtype is not given or None, the dtype is inferred to be the type with the highest precision among
|
||||
the `start`, `end` and `step` parameters.
|
||||
|
||||
Returns:
|
||||
A 1-D Tensor, with the same type as the inputs.
|
||||
|
||||
Raises:
|
||||
TypeError: If `start`, `end` or `step` is not an int or a float or a TensorScalar in valid dtypes.
|
||||
ValueError: If `step` = 0.
|
||||
ValueError: If `start` >= `end` when `step` > 0.
|
||||
ValueError: If `start` <= `end` when `step` < 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import mindspore.ops as ops
|
||||
>>> output = arange(1, 6)
|
||||
>>> print(output)
|
||||
>>> print(output.dtype)
|
||||
[1 2 3 4 5]
|
||||
Int64
|
||||
>>> output = arange(0, 3, 1.2)
|
||||
>>> print(output)
|
||||
>>> print(output.dtype)
|
||||
[0. 1.2 2.4]
|
||||
Float32
|
||||
>>> output = arange(7, 1, -2)
|
||||
>>> print(output)
|
||||
>>> print(output.dtype)
|
||||
[7 5 3]
|
||||
Int64
|
||||
>>> output = arange(ms.Tensor(12.0, dtype=ms.float64), 2, ms.Tensor(-1.0, dtype=ms.float32))
|
||||
>>> print(output)
|
||||
>>> print(output.dtype)
|
||||
[12. 11. 10. 9. 8. 7. 6. 5. 4. 3.]
|
||||
Float64
|
||||
"""
|
||||
if end is None:
|
||||
start, end = 0, start
|
||||
max_type = _get_max_type(start, end, step)
|
||||
start = _cast_type(start, max_type)
|
||||
end = _cast_type(end, max_type)
|
||||
step = _cast_type(step, max_type)
|
||||
|
||||
if start.shape != () or end.shape != () or step.shape != ():
|
||||
raise ValueError(f"For arange, the input args must be a TensorScalar,"
|
||||
f" but got start shape:{start.shape}, end shape:{end.shape}, step shape:{step.shape}")
|
||||
data = P.Range()(start, end, step)
|
||||
if dtype is not None:
|
||||
data = cast_(data, dtype)
|
||||
return data
|
||||
|
||||
|
||||
def eye(n, m, t):
|
||||
"""
|
||||
Creates a tensor with ones on the diagonal and zeros in the rest.
|
||||
|
@ -5063,6 +5167,7 @@ __all__ = [
|
|||
'dyn_shape',
|
||||
'rank',
|
||||
'range',
|
||||
'arange',
|
||||
'reshape',
|
||||
'reshape_',
|
||||
'flatten',
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, start=0, end=None, step=1, dtype=None):
|
||||
return ops.arange(start, end, step, dtype=dtype)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_arange_normal(mode):
|
||||
"""
|
||||
Feature: arange
|
||||
Description: Verify the result of arange
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
output1 = net(1, 6)
|
||||
output2 = net(10.0, dtype=ms.int32)
|
||||
output3 = net(ms.Tensor(12.0, dtype=ms.float64), 2, ms.Tensor(-1.0, dtype=ms.float32))
|
||||
assert np.allclose(output1.asnumpy(), np.array([1., 2., 3., 4., 5.]))
|
||||
assert output1.dtype == ms.int64
|
||||
assert np.allclose(output2.asnumpy(), np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]))
|
||||
assert output2.dtype == ms.int32
|
||||
assert np.allclose(output3.asnumpy(), np.array([12., 11., 10., 9., 8., 7., 6., 5., 4., 3.]))
|
||||
assert output3.dtype == ms.float64
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test arange api
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, start=0, end=None, step=1, dtype=None):
|
||||
return ops.arange(start, end, step, dtype=dtype)
|
||||
|
||||
|
||||
def test_arange_normal():
|
||||
"""
|
||||
Feature: arange
|
||||
Description: Test the functionality of arange
|
||||
Expectation: success
|
||||
"""
|
||||
net = Net()
|
||||
_cell_graph_executor.compile(net, 1, 6)
|
Loading…
Reference in New Issue