Added new Tensor APIs and Sts.
This commit is contained in:
parent
7876890d27
commit
f3162d8726
|
@ -71,6 +71,8 @@ mindspore.Tensor
|
|||
mindspore.Tensor.bitwise_or
|
||||
mindspore.Tensor.bitwise_xor
|
||||
mindspore.Tensor.ceil
|
||||
mindspore.Tensor.cholesky
|
||||
mindspore.Tensor.cholesky_inverse
|
||||
mindspore.Tensor.cosh
|
||||
mindspore.Tensor.erf
|
||||
mindspore.Tensor.erfc
|
||||
|
@ -240,7 +242,13 @@ Array操作
|
|||
|
||||
mindspore.Tensor.asnumpy
|
||||
mindspore.Tensor.astype
|
||||
mindspore.Tensor.bool
|
||||
mindspore.Tensor.float
|
||||
mindspore.Tensor.from_numpy
|
||||
mindspore.Tensor.half
|
||||
mindspore.Tensor.int
|
||||
mindspore.Tensor.long
|
||||
mindspore.Tensor.to
|
||||
mindspore.Tensor.to_coo
|
||||
mindspore.Tensor.to_csr
|
||||
|
||||
|
|
|
@ -76,6 +76,8 @@ Element-wise Methods
|
|||
mindspore.Tensor.bitwise_or
|
||||
mindspore.Tensor.bitwise_xor
|
||||
mindspore.Tensor.ceil
|
||||
mindspore.Tensor.cholesky
|
||||
mindspore.Tensor.cholesky_inverse
|
||||
mindspore.Tensor.cosh
|
||||
mindspore.Tensor.erf
|
||||
mindspore.Tensor.erfc
|
||||
|
@ -242,7 +244,13 @@ Type Conversion
|
|||
|
||||
mindspore.Tensor.asnumpy
|
||||
mindspore.Tensor.astype
|
||||
mindspore.Tensor.bool
|
||||
mindspore.Tensor.float
|
||||
mindspore.Tensor.from_numpy
|
||||
mindspore.Tensor.half
|
||||
mindspore.Tensor.int
|
||||
mindspore.Tensor.long
|
||||
mindspore.Tensor.to
|
||||
mindspore.Tensor.to_coo
|
||||
mindspore.Tensor.to_csr
|
||||
|
||||
|
|
|
@ -315,6 +315,14 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"atanh", std::string("atanh")}, // atanh()
|
||||
{"bmm", std::string("bmm")}, // bmm()
|
||||
{"value", std::string("value_")}, // P.Load(param, U)
|
||||
{"to", std::string("to")}, // to()
|
||||
{"bool", std::string("to_bool")}, // bool()
|
||||
{"float", std::string("to_float")}, // float()
|
||||
{"half", std::string("to_half")}, // half()
|
||||
{"int", std::string("to_int")}, // int()
|
||||
{"long", std::string("to_long")}, // long()
|
||||
{"cholesky", std::string("cholesky")}, // cholesky()
|
||||
{"cholesky_inverse", std::string("cholesky_inverse")}, // cholesky_inverse()
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -3125,7 +3125,7 @@ def atanh(x):
|
|||
|
||||
def bmm(input_x, mat2):
|
||||
r"""
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
"""
|
||||
return F.bmm(input_x, mat2)
|
||||
|
||||
|
@ -3135,3 +3135,59 @@ def value_(x):
|
|||
Get the value of Parameter or Tensor x. If x is Parameter, will change the type from RefTensor to Tensor.
|
||||
"""
|
||||
return P.Load()(x, monad.U)
|
||||
|
||||
|
||||
def to(input_x, dtype):
|
||||
r"""
|
||||
Performs tensor dtype conversion.
|
||||
"""
|
||||
return P.Cast()(input_x, dtype)
|
||||
|
||||
|
||||
def to_bool(input_x):
|
||||
r"""
|
||||
Converts input tensor dtype to bool.
|
||||
"""
|
||||
return P.Cast()(input_x, mstype.bool_)
|
||||
|
||||
|
||||
def to_float(input_x):
|
||||
r"""
|
||||
Converts input tensor dtype to float32.
|
||||
"""
|
||||
return P.Cast()(input_x, mstype.float32)
|
||||
|
||||
|
||||
def to_half(input_x):
|
||||
r"""
|
||||
Converts input tensor dtype to float16.
|
||||
"""
|
||||
return P.Cast()(input_x, mstype.float16)
|
||||
|
||||
|
||||
def to_int(input_x):
|
||||
r"""
|
||||
Converts input tensor dtype to int32.
|
||||
"""
|
||||
return P.Cast()(input_x, mstype.int32)
|
||||
|
||||
|
||||
def to_long(input_x):
|
||||
r"""
|
||||
Converts input tensor dtype to int64.
|
||||
"""
|
||||
return P.Cast()(input_x, mstype.int64)
|
||||
|
||||
|
||||
def cholesky(input_x, upper=False):
|
||||
r"""
|
||||
Computes the Cholesky decomposition of a symmetric positive-definite matrix
|
||||
"""
|
||||
return F.cholesky(input_x, upper=upper)
|
||||
|
||||
|
||||
def cholesky_inverse(input_x, upper=False):
|
||||
r"""
|
||||
Computes the inverse of the positive definite matrix using cholesky matrix factorization.
|
||||
"""
|
||||
return F.cholesky_inverse(input_x, upper=upper)
|
||||
|
|
|
@ -5594,7 +5594,7 @@ class Tensor(Tensor_):
|
|||
Args:
|
||||
mat2 (Tensor): The tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`.
|
||||
|
||||
Outputs:
|
||||
Returns:
|
||||
Tensor, the shape of the output tensor is :math:`(*B, N, M)`.
|
||||
|
||||
Raises:
|
||||
|
@ -5622,6 +5622,218 @@ class Tensor(Tensor_):
|
|||
return tensor_operator_registry.get('bmm')(self, mat2)
|
||||
|
||||
|
||||
def to(self, dtype):
|
||||
r"""
|
||||
Performs tensor dtype conversion.
|
||||
|
||||
Args:
|
||||
dtype (dtype.Number): The valid data type of the output tensor. Only constant value is allowed.
|
||||
|
||||
Returns:
|
||||
Tensor, converted to the specified `dtype`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `dtype` is not a Number.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||
>>> input_x = Tensor(input_np)
|
||||
>>> dtype = mindspore.int32
|
||||
>>> output = input_x.to(dtype)
|
||||
>>> print(output.dtype)
|
||||
Int32
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('to')()(self, dtype)
|
||||
|
||||
|
||||
def bool(self):
|
||||
r"""
|
||||
Converts input tensor dtype to `bool`.
|
||||
|
||||
Returns:
|
||||
Tensor, converted to the `bool` dtype.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([2,2]), mindspore.float32)
|
||||
>>> output = input_x.bool()
|
||||
>>> print(output.dtype)
|
||||
Bool
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('bool')()(self, mstype.bool_)
|
||||
|
||||
|
||||
def float(self):
|
||||
r"""
|
||||
Converts input tensor dtype to `float32`.
|
||||
|
||||
Returns:
|
||||
Tensor, converted to the `float32` dtype.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([2,2]), mindspore.int32)
|
||||
>>> output = input_x.float()
|
||||
>>> print(output.dtype)
|
||||
Float32
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('float')()(self, mstype.float32)
|
||||
|
||||
|
||||
def half(self):
|
||||
r"""
|
||||
Converts input tensor dtype to `float16`.
|
||||
|
||||
Returns:
|
||||
Tensor, converted to the `float16` dtype.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([2,2]), mindspore.int32)
|
||||
>>> output = input_x.half()
|
||||
>>> print(output.dtype)
|
||||
Float16
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('half')()(self, mstype.float16)
|
||||
|
||||
|
||||
def int(self):
|
||||
r"""
|
||||
Converts input tensor dtype to `int32`.
|
||||
|
||||
Returns:
|
||||
Tensor, converted to the `int32` dtype.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([2,2]), mindspore.float32)
|
||||
>>> output = input_x.int()
|
||||
>>> print(output.dtype)
|
||||
Int32
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('int')()(self, mstype.int32)
|
||||
|
||||
|
||||
def long(self):
|
||||
r"""
|
||||
Converts input tensor dtype to `int64`.
|
||||
|
||||
Returns:
|
||||
Tensor, converted to the `int64` dtype.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([2,2]), mindspore.int32)
|
||||
>>> output = input_x.long()
|
||||
>>> print(output.dtype)
|
||||
Int64
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('long')()(self, mstype.int64)
|
||||
|
||||
|
||||
def cholesky(self, upper=False):
|
||||
r"""
|
||||
Computes the Cholesky decomposition of a symmetric positive-definite matrix :math:`A`
|
||||
or for batches of symmetric positive-definite matrices.
|
||||
|
||||
If `upper` is `True`, the returned matrix :math:`U` is upper-triangular, and the decomposition has the form:
|
||||
|
||||
.. math::
|
||||
A = U^TU
|
||||
|
||||
If `upper` is `False`, the returned matrix :math:`L` is lower-triangular, and the decomposition has the form:
|
||||
|
||||
.. math::
|
||||
A = LL^T
|
||||
|
||||
Args:
|
||||
upper (bool): Flag that indicates whether to return a upper or lower triangular matrix.
|
||||
Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and data type as input tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `upper` is not a bool.
|
||||
TypeError: If dtype of tensor is not one of: float64, float32.
|
||||
ValueError: If tensor is not batch square.
|
||||
ValueError: If tensor is not symmetric positive definite.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1.0, 1.0], [1.0, 2.0]]), mindspore.float32)
|
||||
>>> output = x.cholesky(upper=False)
|
||||
>>> print(output)
|
||||
[[1. 0.]
|
||||
[1. 1.]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('cholesky')(upper=upper)(self)
|
||||
|
||||
|
||||
def cholesky_inverse(self, upper=False):
|
||||
r"""
|
||||
Returns the inverse of the positive definite matrix using cholesky matrix factorization.
|
||||
|
||||
If `upper` is `False`, :math:`U` is a lower triangular such that the output tensor is
|
||||
|
||||
.. math::
|
||||
inv = (UU^{T})^{{-1}}
|
||||
|
||||
If `upper` is `True`, :math:`U` is an upper triangular such that the output tensor is
|
||||
|
||||
.. math::
|
||||
inv = (U^{T}U)^{{-1}}
|
||||
|
||||
Note:
|
||||
The tensor must be either an upper triangular matrix or a lower triangular matrix.
|
||||
|
||||
Args:
|
||||
upper(bool): Whether to return a lower or upper triangular matrix. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape and dtype as input tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of input tensor is not one of: float32, float64.
|
||||
ValueError: If the dimension of input tensor is not equal to 2.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[2,0,0], [4,1,0], [-1,1,2]]), mindspore.float32)
|
||||
>>> output = x.cholesky_inverse()
|
||||
>>> print(output)
|
||||
[[ 5.8125 -2.625 0.625 ]
|
||||
[-2.625 1.25 -0.25 ]
|
||||
[ 0.625 -0.25 0.25 ]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('cholesky_inverse')(upper=upper)(self)
|
||||
|
||||
|
||||
class RowTensor(RowTensor_):
|
||||
"""
|
||||
A sparse representation of a set of tensor slices at given indices.
|
||||
|
|
|
@ -492,5 +492,13 @@ tensor_operator_registry.register('argmin_with_value', min)
|
|||
tensor_operator_registry.register('coo_add', sparse_add)
|
||||
tensor_operator_registry.register('top_k', P.TopK)
|
||||
tensor_operator_registry.register('isfinite', P.IsFinite)
|
||||
tensor_operator_registry.register('to', P.Cast)
|
||||
tensor_operator_registry.register('bool', P.Cast)
|
||||
tensor_operator_registry.register('float', P.Cast)
|
||||
tensor_operator_registry.register('half', P.Cast)
|
||||
tensor_operator_registry.register('int', P.Cast)
|
||||
tensor_operator_registry.register('long', P.Cast)
|
||||
tensor_operator_registry.register('cholesky', P.Cholesky)
|
||||
tensor_operator_registry.register('cholesky_inverse', P.CholeskyInverse)
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
__all__.remove('Primitive')
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_bool_tensor_api():
|
||||
"""
|
||||
Feature: test bool tensor API.
|
||||
Description: test bool dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be bool.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.bool()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.bool_
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bool_tensor_api_modes():
|
||||
"""
|
||||
Feature: test bool tensor API for different modes.
|
||||
Description: test bool dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be bool.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_bool_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_bool_tensor_api()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_float_tensor_api():
|
||||
"""
|
||||
Feature: test float tensor API.
|
||||
Description: test float dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float32.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.int32)
|
||||
output = x.float()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.float32
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_float_tensor_api_modes():
|
||||
"""
|
||||
Feature: test float tensor API for different modes.
|
||||
Description: test float dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float32.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_float_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_float_tensor_api()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_half_tensor_api():
|
||||
"""
|
||||
Feature: test half tensor API.
|
||||
Description: test half dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float16.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.half()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.float16
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_half_tensor_api_modes():
|
||||
"""
|
||||
Feature: test half tensor API for different modes.
|
||||
Description: test half dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float16.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_half_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_half_tensor_api()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_int_tensor_api():
|
||||
"""
|
||||
Feature: test int tensor API.
|
||||
Description: test int dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int32.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.int()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.int32
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_int_tensor_api_modes():
|
||||
"""
|
||||
Feature: test int tensor API for different modes.
|
||||
Description: test int dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int32.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_int_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_int_tensor_api()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_long_tensor_api():
|
||||
"""
|
||||
Feature: test long tensor API.
|
||||
Description: test long dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int64.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.long()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.int64
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_long_tensor_api_modes():
|
||||
"""
|
||||
Feature: test long tensor API for different modes.
|
||||
Description: test long dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int64.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_long_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_long_tensor_api()
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_to_tensor_api(dtype):
|
||||
"""
|
||||
Feature: test to tensor API.
|
||||
Description: test to API for dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be same as op arg.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]))
|
||||
output = x.to(dtype)
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == dtype
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_to_tensor_api_modes():
|
||||
"""
|
||||
Feature: test to tensor API for different modes.
|
||||
Description: test to API for dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be same as op arg.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
test_to_tensor_api(ms.bool_)
|
||||
test_to_tensor_api(ms.float16)
|
||||
test_to_tensor_api(ms.float32)
|
||||
test_to_tensor_api(ms.float64)
|
||||
test_to_tensor_api(ms.int8)
|
||||
test_to_tensor_api(ms.uint8)
|
||||
test_to_tensor_api(ms.int16)
|
||||
test_to_tensor_api(ms.uint16)
|
||||
test_to_tensor_api(ms.int32)
|
||||
test_to_tensor_api(ms.int64)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
test_to_tensor_api(ms.bool_)
|
||||
test_to_tensor_api(ms.float16)
|
||||
test_to_tensor_api(ms.float32)
|
||||
test_to_tensor_api(ms.float64)
|
||||
test_to_tensor_api(ms.int8)
|
||||
test_to_tensor_api(ms.uint8)
|
||||
test_to_tensor_api(ms.int16)
|
||||
test_to_tensor_api(ms.uint16)
|
||||
test_to_tensor_api(ms.int32)
|
||||
test_to_tensor_api(ms.int64)
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_bool_tensor_api():
|
||||
"""
|
||||
Feature: test bool tensor API.
|
||||
Description: test bool dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be bool.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.bool()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.bool_
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_bool_tensor_api_modes():
|
||||
"""
|
||||
Feature: test bool tensor API for different modes.
|
||||
Description: test bool dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be bool.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_bool_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_bool_tensor_api()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bool_tensor_api_modes()
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_float_tensor_api():
|
||||
"""
|
||||
Feature: test float tensor API.
|
||||
Description: test float dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float32.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.int32)
|
||||
output = x.float()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.float32
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_float_tensor_api_modes():
|
||||
"""
|
||||
Feature: test float tensor API for different modes.
|
||||
Description: test float dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float32.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_float_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_float_tensor_api()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_float_tensor_api_modes()
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_half_tensor_api():
|
||||
"""
|
||||
Feature: test half tensor API.
|
||||
Description: test half dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float16.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.half()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.float16
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_half_tensor_api_modes():
|
||||
"""
|
||||
Feature: test half tensor API for different modes.
|
||||
Description: test half dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float16.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_half_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_half_tensor_api()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_half_tensor_api_modes()
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_int_tensor_api():
|
||||
"""
|
||||
Feature: test int tensor API.
|
||||
Description: test int dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int32.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.int()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.int32
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_int_tensor_api_modes():
|
||||
"""
|
||||
Feature: test int tensor API for different modes.
|
||||
Description: test int dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int32.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_int_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_int_tensor_api()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_int_tensor_api_modes()
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_long_tensor_api():
|
||||
"""
|
||||
Feature: test long tensor API.
|
||||
Description: test long dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int64.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.long()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.int64
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_long_tensor_api_modes():
|
||||
"""
|
||||
Feature: test long tensor API for different modes.
|
||||
Description: test long dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int64.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_long_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_long_tensor_api()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_long_tensor_api_modes()
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_to_tensor_api(dtype):
|
||||
"""
|
||||
Feature: test to tensor API.
|
||||
Description: test to API for dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be same as op arg.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]))
|
||||
output = x.to(dtype)
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == dtype
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_to_tensor_api_modes():
|
||||
"""
|
||||
Feature: test to tensor API for different modes.
|
||||
Description: test to API for dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be same as op arg.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_to_tensor_api(ms.bool_)
|
||||
test_to_tensor_api(ms.float16)
|
||||
test_to_tensor_api(ms.float32)
|
||||
test_to_tensor_api(ms.float64)
|
||||
test_to_tensor_api(ms.int8)
|
||||
test_to_tensor_api(ms.uint8)
|
||||
test_to_tensor_api(ms.int16)
|
||||
test_to_tensor_api(ms.uint16)
|
||||
test_to_tensor_api(ms.int32)
|
||||
test_to_tensor_api(ms.int64)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_to_tensor_api(ms.bool_)
|
||||
test_to_tensor_api(ms.float16)
|
||||
test_to_tensor_api(ms.float32)
|
||||
test_to_tensor_api(ms.float64)
|
||||
test_to_tensor_api(ms.int8)
|
||||
test_to_tensor_api(ms.uint8)
|
||||
test_to_tensor_api(ms.int16)
|
||||
test_to_tensor_api(ms.uint16)
|
||||
test_to_tensor_api(ms.int32)
|
||||
test_to_tensor_api(ms.int64)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_to_tensor_api_modes()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_bool_tensor_api():
|
||||
"""
|
||||
Feature: test bool tensor API.
|
||||
Description: test bool dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be bool.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.bool()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.bool_
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bool_tensor_api_modes():
|
||||
"""
|
||||
Feature: test bool tensor API for different modes.
|
||||
Description: test bool dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be bool.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_bool_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_bool_tensor_api()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_float_tensor_api():
|
||||
"""
|
||||
Feature: test float tensor API.
|
||||
Description: test float dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float32.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.int32)
|
||||
output = x.float()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.float32
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_float_tensor_api_modes():
|
||||
"""
|
||||
Feature: test float tensor API for different modes.
|
||||
Description: test float dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float32.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_float_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_float_tensor_api()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_half_tensor_api():
|
||||
"""
|
||||
Feature: test half tensor API.
|
||||
Description: test half dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float16.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.half()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.float16
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_half_tensor_api_modes():
|
||||
"""
|
||||
Feature: test half tensor API for different modes.
|
||||
Description: test half dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be float16.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_half_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_half_tensor_api()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_int_tensor_api():
|
||||
"""
|
||||
Feature: test int tensor API.
|
||||
Description: test int dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int32.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.int()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.int32
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_int_tensor_api_modes():
|
||||
"""
|
||||
Feature: test int tensor API for different modes.
|
||||
Description: test int dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int32.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_int_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_int_tensor_api()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_long_tensor_api():
|
||||
"""
|
||||
Feature: test long tensor API.
|
||||
Description: test long dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int64.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]), ms.float32)
|
||||
output = x.long()
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == ms.int64
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_long_tensor_api_modes():
|
||||
"""
|
||||
Feature: test long tensor API for different modes.
|
||||
Description: test long dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be int64.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_long_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_long_tensor_api()
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def test_to_tensor_api(dtype):
|
||||
"""
|
||||
Feature: test to tensor API.
|
||||
Description: test to API for dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be same as op arg.
|
||||
"""
|
||||
dtype_op = P.DType()
|
||||
x = Tensor(np.ones([2, 3, 1]))
|
||||
output = x.to(dtype)
|
||||
assert x.shape == output.shape
|
||||
assert dtype_op(output) == dtype
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_to_tensor_api_modes():
|
||||
"""
|
||||
Feature: test to tensor API for different modes.
|
||||
Description: test to API for dtype tensor conversion.
|
||||
Expectation: the input and output shape should be same. output dtype should be same as op arg.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_to_tensor_api(ms.bool_)
|
||||
test_to_tensor_api(ms.float16)
|
||||
test_to_tensor_api(ms.float32)
|
||||
test_to_tensor_api(ms.float64)
|
||||
test_to_tensor_api(ms.int8)
|
||||
test_to_tensor_api(ms.uint8)
|
||||
test_to_tensor_api(ms.int16)
|
||||
test_to_tensor_api(ms.uint16)
|
||||
test_to_tensor_api(ms.int32)
|
||||
test_to_tensor_api(ms.int64)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_to_tensor_api(ms.bool_)
|
||||
test_to_tensor_api(ms.float16)
|
||||
test_to_tensor_api(ms.float32)
|
||||
test_to_tensor_api(ms.float64)
|
||||
test_to_tensor_api(ms.int8)
|
||||
test_to_tensor_api(ms.uint8)
|
||||
test_to_tensor_api(ms.int16)
|
||||
test_to_tensor_api(ms.uint16)
|
||||
test_to_tensor_api(ms.int32)
|
||||
test_to_tensor_api(ms.int64)
|
Loading…
Reference in New Issue