Remove needless tensor api

This commit is contained in:
shaojunsong 2022-12-23 16:13:06 +08:00
parent 2ef887450b
commit 22507e7eb6
13 changed files with 0 additions and 927 deletions

View File

@ -1,13 +0,0 @@
mindspore.Tensor.as_tensor
==========================
.. py:method:: mindspore.Tensor.as_tensor(data, dtype=None)
将数据转换为mindspore中的张量。
参数:
- **data** (array_like) - 张量的初始数据。可以是列表、元组、NumPy.ndarray、标量和其他类型。
- **dtype** (mindspore.dtype, 可选) - 返回张量的所需数据类型。默认值:如果为"None",则从数据推断数据类型。
返回:
Tensor数据类型在mindspore的数据类型中。

View File

@ -1,14 +0,0 @@
mindspore.Tensor.empty_strided
===============================
.. py:method:: mindspore.Tensor.empty_strided(size, stride, dtype=mstype.float64, seed=None)
使用指定的“大小”和“步幅”创建张量,并用未定义的数据填充。
参数:
- **size** (tuple[ints]) - 输出张量的形状。
- **stride** (tuple[ints]) - 输出张量的步幅。
- **dtype** (mindspore.dtype, 可选) - 返回张量的所需数据类型。
返回:
具有指定大小和步幅并填充了未定义数据的张量。

View File

@ -1,15 +0,0 @@
mindspore.Tensor.frombuffer
============================
.. py:method:: mindspore.Tensor.frombuffer(buffer, dtype=mstype.float64, count=-1, offset=0)
从实现Python缓冲区协议的对象创建一维 `Tensor`。跳过缓冲区中的 `offset` 字节,并且取出数据类型为 `dtype``count` 个数据。
参数:
- **buffer** (object) - 公开缓冲区接口的Python对象。
- **dtype** (mindspore.dtype) - 返回张量的所需数据类型。
- **count** (int, 可选) - 要读取的所需元素的数量。如果为负值,将读取所有元素(直到缓冲区结束)。默认值:-1。
- **offset** (int, 可选) - 缓冲区开始时要跳过的字节数。默认值0。
返回:
来自实现Python缓冲协议的对象的一维张量。

View File

@ -1,19 +0,0 @@
mindspore.Tensor.multinomial
=============================
.. py:method:: mindspore.Tensor.multinomial(num_samples, seed=0, seed2=0)
返回从相应的张量输入行。输入行不需要求和为1(在这种情况下,我们使用这些值作为权重)但必须是非负的、有限的并且具有非零和。self必须是输入张量包含概率总和的必须是1或2维。
参数:
- **num_samples** (int32) - 要绘制的样本数。
- **seed** (int) - 随机种子必须为非负数。默认值0。
- **seed2** (int) - 随机seed2必须为非负数。默认值0。
返回:
与self具有相同行的张量每行具有num_samples采样索引。
异常:
- **TypeError** - 如果 `seed``seed2` 都不是int。
- **TypeError** - 如果 `self` 不是数据类型为float32的Tensor。
- **TypeError** - 如果 `num_samples` 的数据类型不是int32。

View File

@ -1,23 +0,0 @@
mindspore.Tensor.poisson
==========================
.. py:method:: mindspore.Tensor.poisson(shape, mean, seed=0, seed2=0)
返回与input大小相同的张量其中每个元素都是从泊松采样的input中相应元素给出的速率参数分布。张量self的数值作为泊松分布的μ参数。
.. math::
\text{out}_i \sim \text{Poisson}(\text{input}_i)out*i*Poisson(input*i*)
参数:
- **shape** (tuple) - 要生成的随机张量的形状。只允许使用常量值。
- **seed** (int, option) - 设置随机种子0到2**32
- **seed2** (int, option) - 将随机seed2设置为0到2**32
返回:
Tensor形状与input_Tensor相同。
异常:
- **TypeError** - 如果 `seed``seed2` 都不是int。
- **TypeError** - 如果 `shape` 不是元组。
- **TypeError** - 如果 `mean` 不是数据类型不是float32的Tensor。

View File

@ -1,15 +0,0 @@
mindspore.Tensor.rand_like
==========================
.. py:method:: mindspore.Tensor.rand_like(seed=None)
返回与填充的输入大小相同的张量,数值为区间[0,1)上均匀分布的随机数。
参数:
- **seed** (int, option) - 设置随机种子(0到2**32)。
返回:
Tensor形状与self相同。
异常:
- **TypeError** - 如果input_sensor的数据类型不是int或float

View File

@ -1,17 +0,0 @@
mindspore.Tensor.randint_like
==============================
.. py:method:: mindspore.Tensor.randint_like(high, low=0, seed=None)
返回与输入张量大小相同的张量,数值为区间[lowhigh]上的随机数如果只输入一个int类型的数据默认值为high如果输入两个整数则分别为low和high。
参数:
- **low** (int, 可选) - 要从分布中提取的最小整数。默认值0。
- **high** (int) - 高于要从分布中提取的最高整数的一个。
- **seed** (int, 可选) - 设置随机种子(0到2**32)。
返回:
Tensor形状与self相同。
异常:
- **TypeError** - 如果input_sensor的数据类型不是int或float。

View File

@ -1,15 +0,0 @@
mindspore.Tensor.randn_like
============================
.. py:method:: mindspore.Tensor.randn_like(seed=None)
返回一个与输入大小相同的张量该张量由均值为0、方差为1的正态分布中的随机数填充。
参数:
- **seed** (int, 可选) - 设置随机种子(0到2**32)。
返回:
Tensor形状与self相同。
异常:
- **TypeError** - 如果self的数据类型不是int或float。

View File

@ -1,19 +0,0 @@
mindspore.Tensor.randperm
==========================
.. py:method:: mindspore.Tensor.randperm(max_length=1, pad=-1)
生成从0到n-1的n个随机样本不重复。如果 `max_length` >n最后的 `maxlength-n` 元素将填充 `pad`
参数:
- **max_length** (int) - 预期获取的项数该数字必须大于0。默认值1。
- **pad** (int) - 要填充的pad值。默认值-1。
- **dtype** (mindspore.dtype) - 输出的类型。默认值mindspore.int32。
返回:
Tensorshape为(`max_length`,),类型为:`dtype`
异常:
- **TypeError** - 如果 `max_length``pad` 不是int。
- **TypeError** - 如果 `self` 有非int元素。
- **TypeError** - 如果 `self` 有负数元素。

View File

@ -513,149 +513,6 @@ class Tensor(Tensor_):
return Tensor(Tensor_.from_numpy(array))
@staticmethod
def frombuffer(buffer, dtype=mstype.float64, count=-1, offset=0):
r"""
Creates a 1-dimensional :class:`Tensor` from an object that implements
the Python buffer protocol.
Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of
the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count`
elements.
Args:
buffer (object): a Python object that exposes the buffer interface.
dtype (mindspore.dtype): the desired data type of returned tensor.
count (int, optional): the number of desired elements to be read. If negative,
all the elements (until the end of the buffer) will be read. Default: -1.
offset (int, optional): the number of bytes to skip at the start of the buffer. Default: 0.
Returns:
a 1-dimensional Tensor from an object that implements the Python buffer protocol.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from array import array
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> input_array = array("d", [1, 2, 3, 4])
>>> input_array
array('d', [1.0, 2.0, 3.0, 4.0])
>>> output = Tensor.frombuffer(input_array, mindspore.int32)
>>> print(output)
[1 2 3 4]
"""
res = np.frombuffer(buffer=buffer, dtype=np.float64, count=count, offset=offset)
result = Tensor(res, dtype=dtype)
return result
@staticmethod
def empty_strided(size, stride, dtype=mstype.float64, seed=None):
r"""
Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data.
Args:
size (tuple of python:ints): the shape of the output tensor.
stride (tuple of python:ints): the strides of the output tensor.
dtype (mindspore.dtype, optional): the desired data type of returned tensor.
Returns:
a tensor with the specified size and stride and filled with undefined data.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor
>>> size = (3, 3)
>>> stride = (1, 3)
>>> output = Tensor.empty_strided(size, stride, seed = 0)
>>> print(output)
[[0.00000000e+00 7.15189366e+10 0.00000000e+00]
[0.00000000e+00 0.00000000e+00 6.45894113e+10]
[0.00000000e+00 8.91773001e+10 9.63662761e+10]]
"""
np.random.seed(seed)
tensor_ = Tensor(np.random.uniform(low=0, high=10e10, size=size))
tensor_array = tensor_.asnumpy()
stride_tensor = tensor_.as_strided(shape=size, strides=stride)
stride_array = stride_tensor.asnumpy()
stride_array.resize(len(stride_array) * len(stride_array[0]))
for i in range(size[0]):
for j in range(size[1]):
if not sum(stride_array - tensor_array[i][j]) < 0.01:
tensor_array[i][j] = 0.0
return Tensor(tensor_array, dtype=dtype)
@staticmethod
def poisson(shape, mean, seed=0, seed2=0):
r"""
Returns a tensor of the same size as `input` with each element sampled from a Poisson
distribution with rate parameter given by the corresponding element in `input` i.e.,
\text{out}_i \sim \text{Poisson}(\text{input}_i)out*i*Poisson(input*i*),
and self as a tensor is the μ parameter .the distribution was constructed with.
The parameter defines mean number of occurrences of the event.
It must be greater than 0. With float32 data type.
Args:
seed (int, option): set the random seed (0 to 2**32)
seed2 (int, option): set the random seed2 (0 to 2**32)
Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
Returns:
out (Union[Tensor, int]), with the same shape as input_tensor.
Raises:
TypeError: If neither `seed` nor `seed2` is an int.
TypeError: If `shape` is not a tuple.
TypeError: If `mean` is not a Tensor whose dtype is not float32.
Supported Platforms:
``Ascend``
Examples:
>>> shape = (4, 1)
>>> mean = Tensor(np.array([5.0, 10.0]), mstype.float32)
>>> output = Tensor.Poisson(shape, mean, seed=5)
>>> result = output.shape
>>> print(result)
(4, 2)
"""
return tensor_operator_registry.get('poisson')(seed, seed2)(shape, mean)
@staticmethod
def as_tensor(data, dtype=None):
r"""
convert data to tensor in mindspore.
Args:
data (array_like): Initial data for the tensor. Can be a list, tuple,
NumPy ndarray, scalar, and other types.
dtype (mindspore.dtype, optional): the desired data type of returned tensor.
Default: if None, infers data type from data.
Returns:
Tensor contains the data and the dtype is in mindspore.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> input_data = np.array([1, 2, 3])
>>> ms_tensor = Tensor.as_tensor(input_data)
>>> ms_tensor
Tensor(shape=[3], dtype=Int64, value= [1, 2, 3])
"""
return Tensor(data, dtype=dtype)
@staticmethod
def _use_logical_kernel(me, other) -> bool:
"""
@ -3222,170 +3079,6 @@ class Tensor(Tensor_):
validator.check_is_int(seed, 'seed')
return tensor_operator_registry.get('bernoulli')(self, p, seed)
def multinomial(self, num_samples, seed=0, seed2=0):
r"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
row of tensor input.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum. self must be the input tensor
containing the cumsum of probabilities, must be 1 or 2 dimensions.
Args:
seed (int): Random seed, must be non-negative. Default: 0.
seed2 (int): Random seed2, must be non-negative. Default: 0.
Inputs:
- **num_samples** (int32) - number of samples to draw.
Outputs:
Tensor with the same rows as `self`, each row has num_samples sampled indices.
Raises:
TypeError: If neither `seed` nor `seed2` is an int.
TypeError: If `self` is not a Tensor whose dtype is float32.
TypeError: If dtype of `num_samples` is not int32.
Supported Platforms:
``GPU``
Examples:
>>> from mindspore import Tensor
>>> import mindspore
>>> x = Tensor([0., 9., 4., 0.], mindspore.float32)
>>> output = x.multinomial(num_samples=2,seed=10)
>>> print(output)
[2 1]
"""
self._init_check()
validator.check_non_negative_int(seed, 'seed')
validator.check_non_negative_int(seed2, 'seed')
return tensor_operator_registry.get('multinomial')(seed, seed2)(self, num_samples)
def rand_like(self, seed=None):
r"""
Returns a tensor with the same size as input that is filled with
random numbers from a uniform distribution on the interval [0, 1)
Args:
input_tensor (Union[Tensor, int, float]): the input tensor.
seed (int, option): set the random seed (0 to 2**32).
Returns:
out (Union[Tensor, float]), with the same shape as input_tensor.
Raises:
TypeError: If dtype of the input_tensor is not int or float.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> input_x = Tensor(np.array([[1, 2, 3, 9], [1, 2, 3, 9]]), mindspore.int8)
>>> output = input_x.rand_like(seed = 0)
>>> print(output)
[[0.5488135 0.71518937 0.60276338 0.54488318]
[0.4236548 0.64589411 0.43758721 0.891773 ]]
>>> input_p = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> output = input_p.rand_like(seed = 0)
>>> print(output)
[0.5488135 0.71518937 0.60276338]
"""
input_tensor = self
input_tensor = np.array(input_tensor)
shape_ = input_tensor.shape
input_tensor = input_tensor.reshape(-1)
x = len(input_tensor)
np.random.seed(seed)
return Tensor(np.array([np.random.rand(1) for i in range(x)]).reshape(shape_))
def randint_like(self, high, low=0, seed=None):
r"""
Returns a tensor with the same size as the input tensor,
and the numerical value is a random number on the interval [low, high],
if only one int type data is entered, the default value is high,
if two integers are entered, they are low and high respectively.
Args:
input_tensor (Union[Tensor, int, float]): the size of input will determine size of the output tensor.
low (int, optional) Lowest integer to be drawn from the distribution. Default: 0.
high (int) One above the highest integer to be drawn from the distribution.
seed (int, optional): set the random seed (0 to 2**32).
Returns:
out (Union[Tensor, int]), with the same shape as input_tensor.
Raises:
TypeError: If dtype of the input_tensor is not int or float.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> input_x = Tensor(np.array([1., 2., 3., 4., 5.]), mindspore.float32)
>>> output = input_x.randint_like(20, seed = 0)
>>> print(output)
[12 15 0 3 3]
>>> output = input_x.randint_like(20, 100, seed = 0)
>>> print(output)
[64 67 84 87 87]
"""
input_tensor = self
input_tensor = np.array(input_tensor)
shape_ = input_tensor.shape
input_tensor = input_tensor.reshape(-1)
if low > high:
high, low = low, high
x = len(input_tensor)
np.random.seed(seed)
return Tensor(np.array([np.random.randint(low, high) for i in range(x)]).reshape(shape_))
def randn_like(self, seed=None):
r"""
Returns a tensor with the same size as input that is filled with random
numbers from a normal distribution with mean 0 and variance 1.
Args:
input_tensor (Union[Tensor, int, float]): the size of input will determine size of the output tensor.
seed (int, optional): set the random seed (0 to 2**32).
Returns:
out (Union[Tensor, int]), with the same shape as input_tensor.
Raises:
TypeError: If dtype of the input_tensor is not int or float.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> input_x = Tensor(np.array([1., 2., 3., 4., 5.]), mindspore.float32)
>>> output = input_x.randn_like(seed = 0)
>>> print(output)
[1.7640524 0.4001572 0.978738 2.2408931 1.867558 ]
>>> input_p = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), mindspore.int32)
>>> output = input_p.randn_like(seed = 0)
>>> print(output)
[[ 1.7640524 0.4001572 0.978738 2.2408931 1.867558 ]
[-0.9772779 0.95008844 -0.1513572 -0.10321885 0.41059852]]
"""
input_tensor = np.array(self)
shape_ = input_tensor.shape
input_tensor = input_tensor.reshape(-1)
x = len(input_tensor)
np.random.seed(seed)
return Tensor([np.random.randn() for i in range(x)]).reshape(shape_)
def as_strided(self, shape=None, strides=None, subok=False, writeable=True):
r"""
as_strided(input, size, stride, storage_offset=0) -> Tensor
@ -3419,43 +3112,6 @@ class Tensor(Tensor_):
strides = tuple(np.array(strides) * n)
return Tensor(np.lib.stride_tricks.as_strided(x, shape, strides, subok, writeable), dtype=dtype_)
def randperm(self, max_length=1, pad=-1):
r"""
Generates n random samples from 0 to n-1 without repeating. If `max_length` > n,
the last `max_length-n` elements will be filled with `pad`.
Args:
max_length (int): Number of items expected to get and the number must be greater than 0. Default: 1.
pad (int): The pad value to be filled. Default: -1.
dtype (mindspore.dtype): The type of output. Default: mindspore.int32.
Inputs:
- **n** (Tensor[int32]) - The input tensor with shape: (1,) and the number must be in [0, `max_length`].
Outputs:
- **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`.
Raises:
TypeError: If neither `max_length` nor `pad` is an int.
TypeError: If `self` has non-Int elements.
TypeError: If `self` has negative elements.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> # The result of every execution is different because this operator will generate n random samples.
>>> from mindspore import Tensor
>>> import mindspore
>>> n = Tensor([20], dtype=mindspore.int32)
>>> output = n.randperm(max_length=30, pad=-1)
>>> print(output)
[15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 1 12 3 7
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
"""
self._init_check()
return tensor_operator_registry.get('randperm')(max_length, pad)(self)
def random_categorical(self, num_sample, seed=0, dtype=mstype.int64):
r"""
For details, please refer to :func:`mindspore.ops.random_categorical`.

View File

@ -1,159 +0,0 @@
from array import array
import numpy as np
import pytest
from mindspore import context, Tensor
import mindspore.common.dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_poisson():
"""
Feature: poisson
Description: test the function of poisson.
Expectation: success
"""
shape = (4, 1)
mean = Tensor(np.array([5.0, 10.0]), mstype.float32)
output = Tensor.poisson(shape, mean, seed=5)
result = output.shape
assert result == (4, 2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_rand_like():
"""
Feature: rand_like
Description: test the function of rand_like.
Expectation: success
"""
input_x = Tensor(np.array([[1, 2, 3, 9], [1, 2, 3, 9]]), mstype.int32)
output = input_x.rand_like(seed=0)
expect_res = np.array([[5.48813504e-01, 7.15189366e-01, 6.02763376e-01, 5.44883183e-01],
[4.23654799e-01, 6.45894113e-01, 4.37587211e-01, 8.91773001e-01]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_randint_like():
"""
Feature: randint_like
Description: test the function of randint_like.
Expectation: success
"""
input_x = Tensor(np.array([1., 2., 3., 4., 5.]), mstype.float32)
output = input_x.randint_like(20, 100, seed=0)
expect_res = np.array([64, 67, 84, 87, 87])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_randn_like():
"""
Feature: randn_like
Description: test the function of randn_like.
Expectation: success
"""
input_p = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), mstype.int32)
output = input_p.randn_like(seed=0)
expect_res = np.array([[1.76405239e+00, 4.00157213e-01, 9.78738010e-01, 2.24089313e+00, 1.86755800e+00],
[-9.77277875e-01, 9.50088441e-01, -1.51357204e-01, -1.03218853e-01, 4.10598516e-01]])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_as_tensor():
"""
Feature: as_tensor
Description: test the function of as_tensor.
Expectation: success
"""
input_data = np.array([1, 2, 3])
ms_tensor = Tensor.as_tensor(input_data)
expect_res = np.array([1, 2, 3])
assert np.allclose(ms_tensor.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_as_strided():
"""
Feature: rand_like
Description: test the function of as_stride.
Expectation: success
"""
input_array = np.arange(9, dtype=np.int32).reshape(3, 3)
output = Tensor(input_array).as_strided((2, 2), (1, 1))
expect_res = np.array([[0, 1], [1, 2]]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_frombuffer():
"""
Feature: rand_like
Description: test the function of frombuffer.
Expectation: success
"""
input_array = array("d", [1, 2, 3, 4])
output = Tensor.frombuffer(input_array, mstype.int32)
expect_res = np.array([1, 2, 3, 4]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_empty_strided():
"""
Feature: rand_like
Description: test the function of empty_strided.
Expectation: success
"""
size = (3, 3)
stride = (1, 3)
output = Tensor.empty_strided(size, stride, seed=0)
expect_res = np.array([[0.00000000e+00, 7.15189366e+10, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 6.45894113e+10],
[0.00000000e+00, 8.91773001e+10, 9.63662761e+10]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_randpermn():
"""
Feature: rand_like
Description: test the function of randpermn.
Expectation: success
"""
n = Tensor([20], dtype=mstype.int32)
output = n.randperm(max_length=30, pad=-1)
set_a = set(np.linspace(0, 19, 20).astype(int))
set_b = set(output.asnumpy())
assert set_a.issubset(set_b)

View File

@ -1,125 +0,0 @@
from array import array
import numpy as np
import pytest
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_rand_like():
"""
Feature: rand_like
Description: test the function of rand_like.
Expectation: success
"""
input_x = Tensor(np.array([[1, 2, 3, 9], [1, 2, 3, 9]]), mstype.int32)
output = input_x.rand_like(seed=0)
expect_res = np.array([[5.48813504e-01, 7.15189366e-01, 6.02763376e-01, 5.44883183e-01],
[4.23654799e-01, 6.45894113e-01, 4.37587211e-01, 8.91773001e-01]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_randint_like():
"""
Feature: randint_like
Description: test the function of randint_like.
Expectation: success
"""
input_x = Tensor(np.array([1., 2., 3., 4., 5.]), mstype.float32)
output = input_x.randint_like(20, 100, seed=0)
expect_res = np.array([64, 67, 84, 87, 87])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_randn_like():
"""
Feature: randn_like
Description: test the function of randn_like.
Expectation: success
"""
input_p = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), mstype.int32)
output = input_p.randn_like(seed=0)
expect_res = np.array([[1.76405239e+00, 4.00157213e-01, 9.78738010e-01, 2.24089313e+00, 1.86755800e+00],
[-9.77277875e-01, 9.50088441e-01, -1.51357204e-01, -1.03218853e-01, 4.10598516e-01]])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_as_tensor():
"""
Feature: as_tensor
Description: test the function of as_tensor.
Expectation: success
"""
input_data = np.array([1, 2, 3])
ms_tensor = Tensor.as_tensor(input_data)
expect_res = np.array([1, 2, 3])
assert np.allclose(ms_tensor.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_as_strided():
"""
Feature: rand_like
Description: test the function of as_stride.
Expectation: success
"""
input_array = np.arange(9, dtype=np.int32).reshape(3, 3)
output = Tensor(input_array).as_strided((2, 2), (1, 1))
expect_res = np.array([[0, 1], [1, 2]]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_frombuffer():
"""
Feature: rand_like
Description: test the function of frombuffer.
Expectation: success
"""
input_array = array("d", [1, 2, 3, 4])
output = Tensor.frombuffer(input_array, mstype.int32)
expect_res = np.array([1, 2, 3, 4]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
def test_empty_strided():
"""
Feature: rand_like
Description: test the function of empty_strided.
Expectation: success
"""
size = (3, 3)
stride = (1, 3)
output = Tensor.empty_strided(size, stride, seed=0)
expect_res = np.array([[0.00000000e+00, 7.15189366e+10, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 6.45894113e+10],
[0.00000000e+00, 8.91773001e+10, 9.63662761e+10]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect_res)

View File

@ -1,149 +0,0 @@
from array import array
import numpy as np
import pytest
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_multinomial():
"""
Feature: multinomial
Description: test the function of multinomial.
Expectation: success
"""
input_tensor = Tensor([0., 9., 4., 0.], mstype.float32)
output = input_tensor.multinomial(num_samples=2, seed=10)
expect_res = np.array([2, 1])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rand_like():
"""
Feature: rand_like
Description: test the function of rand_like.
Expectation: success
"""
input_x = Tensor(np.array([[1, 2, 3, 9], [1, 2, 3, 9]]), mstype.int8)
output = input_x.rand_like(seed=0)
expect_res = np.array([[5.48813504e-01, 7.15189366e-01, 6.02763376e-01, 5.44883183e-01],
[4.23654799e-01, 6.45894113e-01, 4.37587211e-01, 8.91773001e-01]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randint_like():
"""
Feature: randint_like
Description: test the function of randint_like.
Expectation: success
"""
input_x = Tensor(np.array([1., 2., 3., 4., 5.]), mstype.float32)
output = input_x.randint_like(20, 100, seed=0)
expect_res = np.array([64, 67, 84, 87, 87])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randn_like():
"""
Feature: randn_like
Description: test the function of randn_like.
Expectation: success
"""
input_p = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), mstype.int32)
output = input_p.randn_like(seed=0)
expect_res = np.array([[1.76405239e+00, 4.00157213e-01, 9.78738010e-01, 2.24089313e+00, 1.86755800e+00],
[-9.77277875e-01, 9.50088441e-01, -1.51357204e-01, -1.03218853e-01, 4.10598516e-01]])
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_as_tensor():
"""
Feature: as_tensor
Description: test the function of as_tensor.
Expectation: success
"""
input_data = np.array([1, 2, 3])
ms_tensor = Tensor.as_tensor(input_data)
expect_res = np.array([1, 2, 3])
assert np.allclose(ms_tensor.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_as_strided():
"""
Feature: rand_like
Description: test the function of as_stride.
Expectation: success
"""
input_array = np.arange(9, dtype=np.int32).reshape(3, 3)
output = Tensor(input_array).as_strided((2, 2), (1, 1))
expect_res = np.array([[0, 1], [1, 2]]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_frombuffer():
"""
Feature: rand_like
Description: test the function of frombuffer.
Expectation: success
"""
input_array = array("d", [1, 2, 3, 4])
output = Tensor.frombuffer(input_array, mstype.int32)
expect_res = np.array([1, 2, 3, 4]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_empty_strided():
"""
Feature: rand_like
Description: test the function of empty_strided.
Expectation: success
"""
size = (3, 3)
stride = (1, 3)
output = Tensor.empty_strided(size, stride, seed=0)
expect_res = np.array([[0.00000000e+00, 7.15189366e+10, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 6.45894113e+10],
[0.00000000e+00, 8.91773001e+10, 9.63662761e+10]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_randpermn():
"""
Feature: rand_like
Description: test the function of randpermn.
Expectation: success
"""
n = Tensor([20], dtype=mstype.int32)
output = n.randperm(max_length=30, pad=-1)
set_a = set(np.linspace(0, 19, 20).astype(int))
set_b = set(output.asnumpy())
assert set_a.issubset(set_b)