From bae0d1f79d37b383d72cbcff67af9753183bd1a1 Mon Sep 17 00:00:00 2001 From: alashkari Date: Tue, 31 Jan 2023 16:27:38 -0500 Subject: [PATCH] Updated ops.div & Tensor.div. Added ops.lstsq, Tensor.lstsq, Tensor.mvlgamma, Tensor.matmul, Tensor.maximum, Tensor.ne & Tensor.neg to docs. Added ops.unfold & Tensor.unfold STs. Updated ops.argmax. --- docs/api/api_python/mindspore.ops.rst | 1 + .../api_python/mindspore/mindspore.Tensor.rst | 3 ++ docs/api/api_python_en/Tensor_list.rst | 3 ++ docs/api/api_python_en/mindspore.ops.rst | 1 + .../_extends/parse/standard_method.py | 8 ++-- mindspore/python/mindspore/common/tensor.py | 8 ++-- .../mindspore/ops/function/array_func.py | 38 +++++++-------- .../mindspore/ops/function/math_func.py | 4 +- tests/st/ops/ascend/test_argmax.py | 29 +++++++----- tests/st/ops/cpu/test_argmax_op.py | 15 +++--- tests/st/ops/cpu/test_cholesky_op.py | 45 ++++++++++++++++++ tests/st/ops/cpu/test_div_op.py | 10 ++-- tests/st/ops/cpu/test_unfold_op.py | 46 ++++++++++++++++++ tests/st/ops/gpu/test_argmax_op.py | 38 ++++++++------- tests/st/ops/gpu/test_div_op.py | 10 ++-- tests/st/ops/gpu/test_unfold_op.py | 46 ++++++++++++++++++ tests/st/tensor/test_cholesky.py | 47 +++++++++++++++++++ tests/st/tensor/test_cholesky_inverse.py | 47 +++++++++++++++++++ tests/st/tensor/test_unfold.py | 47 +++++++++++++++++++ 19 files changed, 369 insertions(+), 77 deletions(-) create mode 100644 tests/st/ops/cpu/test_cholesky_op.py create mode 100644 tests/st/ops/cpu/test_unfold_op.py create mode 100644 tests/st/ops/gpu/test_unfold_op.py create mode 100644 tests/st/tensor/test_cholesky.py create mode 100644 tests/st/tensor/test_cholesky_inverse.py create mode 100644 tests/st/tensor/test_unfold.py diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst index 4d5a970139a..a1264f60cf9 100644 --- a/docs/api/api_python/mindspore.ops.rst +++ b/docs/api/api_python/mindspore.ops.rst @@ -393,6 +393,7 @@ Reduction函数 mindspore.ops.inverse mindspore.ops.ger mindspore.ops.kron + mindspore.ops.lstsq mindspore.ops.matmul mindspore.ops.matrix_solve mindspore.ops.matrix_exp diff --git a/docs/api/api_python/mindspore/mindspore.Tensor.rst b/docs/api/api_python/mindspore/mindspore.Tensor.rst index 7ed3317b7b4..f5c428f2e85 100644 --- a/docs/api/api_python/mindspore/mindspore.Tensor.rst +++ b/docs/api/api_python/mindspore/mindspore.Tensor.rst @@ -186,6 +186,7 @@ mindspore.Tensor mindspore.Tensor.lt mindspore.Tensor.masked_fill mindspore.Tensor.masked_select + mindspore.Tensor.matmul mindspore.Tensor.matrix_power mindspore.Tensor.max mindspore.Tensor.maximum @@ -201,6 +202,7 @@ mindspore.Tensor mindspore.Tensor.mT mindspore.Tensor.mul mindspore.Tensor.multiply + mindspore.Tensor.mvlgamma mindspore.Tensor.nan_to_num mindspore.Tensor.nansum mindspore.Tensor.narrow @@ -208,6 +210,7 @@ mindspore.Tensor mindspore.Tensor.ndim mindspore.Tensor.ndimension mindspore.Tensor.ne + mindspore.Tensor.neg mindspore.Tensor.negative mindspore.Tensor.nelement mindspore.Tensor.new_ones diff --git a/docs/api/api_python_en/Tensor_list.rst b/docs/api/api_python_en/Tensor_list.rst index dc2870ac64e..87384e5a5b6 100644 --- a/docs/api/api_python_en/Tensor_list.rst +++ b/docs/api/api_python_en/Tensor_list.rst @@ -192,6 +192,7 @@ mindspore.Tensor.lt mindspore.Tensor.masked_fill mindspore.Tensor.masked_select + mindspore.Tensor.matmul mindspore.Tensor.matrix_power mindspore.Tensor.max mindspore.Tensor.maximum @@ -207,6 +208,7 @@ mindspore.Tensor.mT mindspore.Tensor.mul mindspore.Tensor.multiply + 
mindspore.Tensor.mvlgamma mindspore.Tensor.nan_to_num mindspore.Tensor.nansum mindspore.Tensor.narrow @@ -214,6 +216,7 @@ mindspore.Tensor.ndim mindspore.Tensor.ndimension mindspore.Tensor.ne + mindspore.Tensor.neg mindspore.Tensor.negative mindspore.Tensor.nelement mindspore.Tensor.new_ones diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst index b97777f5261..663773d9928 100644 --- a/docs/api/api_python_en/mindspore.ops.rst +++ b/docs/api/api_python_en/mindspore.ops.rst @@ -393,6 +393,7 @@ Linear Algebraic Functions mindspore.ops.inverse mindspore.ops.ger mindspore.ops.kron + mindspore.ops.lstsq mindspore.ops.matmul mindspore.ops.matrix_solve mindspore.ops.matrix_exp diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 4711e8aaa00..229a5eb6fd4 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -3364,7 +3364,7 @@ def true_divide(divident, divisor): r""" Computes the element-wise division of input tensors. """ - return F.div(divident, divisor, None) + return F.div(divident, divisor, rounding_mode=None) # pylint: disable=redefined-outer-name @@ -4011,11 +4011,11 @@ def multiply(input, other): return F.multiply(input, other) -def div(input, other, rounding_mode=None): +def div(input, value, *, rounding_mode=None): r""" - Divides the tensor `input` by the given input tensor `other` in floating-point type element-wise. + Divides the tensor `input` by the given input tensor `value` in floating-point type element-wise. """ - return F.div(input, other, rounding_mode) + return F.div(input, value, rounding_mode=rounding_mode) def equal(x, y): diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 7f4a3dd84aa..487e42fb7b1 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -828,7 +828,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta): For details, please refer to :func:`mindspore.ops.div`. """ self._init_check() - return tensor_operator_registry.get('div')(self, value, None) + return tensor_operator_registry.get('div')(self, value, rounding_mode=None) def triu(self, diagonal=0): r""" @@ -3620,19 +3620,19 @@ class Tensor(Tensor_, metaclass=_TensorMeta): self._init_check() return tensor_operator_registry.get('multiply')(self, value) - def div(self, other, rounding_mode=None): + def div(self, value, *, rounding_mode=None): r""" For details, please refer to :func:`mindspore.ops.div`. """ self._init_check() - return tensor_operator_registry.get('div')(self, other, rounding_mode) + return tensor_operator_registry.get('div')(self, value, rounding_mode=rounding_mode) def divide(self, value, *, rounding_mode=None): r""" Alias for :func:`mindspore.Tensor.div`. 
""" self._init_check() - return tensor_operator_registry.get('div')(self, value, rounding_mode) + return tensor_operator_registry.get('div')(self, value, rounding_mode=rounding_mode) def equal(self, other): r""" diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 0982ba1019c..80cced847e7 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -49,7 +49,8 @@ from mindspore.ops.operations.array_ops import ( Lstsq, Mvlgamma, CountNonZero, - Tril + Tril, + Argmax ) from mindspore.ops.operations.array_ops import TensorScatterElements from mindspore.common import Tensor @@ -5477,43 +5478,42 @@ def max(x, axis=0, keep_dims=False): return argmax_with_value_op(x) -def argmax(x, axis=None, keepdims=False): +def argmax(input, dim=None, keepdim=False): """ Return the indices of the maximum values of a tensor across a dimension. Args: - x (Tensor): Input tensor. - axis (Union[int, None], optional): The dimension to reduce. If `axis` is None, - the indices of the maximum value within the flattened input will be returned. - Default: None. - keepdims (bool, optional): Whether the output tensor retains the specified - dimension. Ignored if `axis` is None. Default: False. + input (Tensor): Input tensor. + dim (Union[int, None]): The dimension to reduce. If `dim` is None, the indices of the maximum + value within the flattened input will be returned. Default: None. + keepdim (bool): Whether the output tensor retains the specified + dimension. Ignored if `dim` is None. Default: False. Returns: Tensor, indices of the maximum values across a dimension. Raises: - ValueError: If `axis` is out of range. + ValueError: If `dim` is out of range. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Examples: >>> x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(np.float32)) - >>> output = ops.argmax(x, axis=-1) + >>> output = ops.argmax(x, dim=-1) >>> print(output) [1 0 0] """ - if x.shape == (): + if input.shape == (): return Tensor(0) - is_axis_none = False - if axis is None: - x = reshape_(x, (-1,)) - axis = 0 - is_axis_none = True - out = P.Argmax(axis, mstype.int64)(x) - if keepdims and not is_axis_none: - out = expand_dims_(out, axis) + is_dim_none = False + if dim is None: + input = reshape_(input, (-1,)) + dim = 0 + is_dim_none = True + out = _get_cache_prim(Argmax)(dim, mstype.int64)(input) + if keepdim and not is_dim_none: + out = expand_dims_(out, dim) return out diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index f96174d5458..aef84161778 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -883,7 +883,7 @@ def multiply(input, other): return tensor_mul(input, other) -def div(input, other, rounding_mode=None): +def div(input, other, *, rounding_mode=None): """ Divides the first input tensor by the second input tensor in floating-point type element-wise. 
@@ -948,7 +948,7 @@ def divide(x, other, *, rounding_mode=None): Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` """ - return div(x, other, rounding_mode) + return div(x, other, rounding_mode=rounding_mode) def float_power(x, exponent): diff --git a/tests/st/ops/ascend/test_argmax.py b/tests/st/ops/ascend/test_argmax.py index b30fbcd4745..70938da52e3 100644 --- a/tests/st/ops/ascend/test_argmax.py +++ b/tests/st/ops/ascend/test_argmax.py @@ -17,9 +17,11 @@ import pytest import mindspore.context as context import mindspore.nn as nn -from mindspore import Tensor, ops +from mindspore import Tensor from mindspore.common.api import jit from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype context.set_context(device_target="Ascend") @@ -42,24 +44,25 @@ def test_net(): print(output.asnumpy()) -def adaptive_argmax_functional(nptype): - x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype)) - output = ops.argmax(x, axis=-1) - expected = np.array([1, 0, 0]).astype(np.int32) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) +class ArgmaxFuncNet(nn.Cell): + def construct(self, x): + return F.argmax(x, dim=-1) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard -def test_argmax_float32_functional(): +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_functional_argmax(mode): """ - Feature: test argmax functional api. - Description: test float32 inputs. + Feature: Test argmax functional api. + Description: Test argmax functional api for Graph and PyNative modes. Expectation: the result match with expected result. """ - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - adaptive_argmax_functional(np.float32) - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - adaptive_argmax_functional(np.float32) + context.set_context(mode=mode, device_target="Ascend") + x = Tensor([[1, 20, 5], [67, 8, 9], [130, 24, 15]], mstype.float32) + net = ArgmaxFuncNet() + output = net(x) + expect_output = np.array([1, 0, 0]).astype(np.int32) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/ops/cpu/test_argmax_op.py b/tests/st/ops/cpu/test_argmax_op.py index bc1f356fc40..565df1ce2c5 100644 --- a/tests/st/ops/cpu/test_argmax_op.py +++ b/tests/st/ops/cpu/test_argmax_op.py @@ -22,6 +22,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor, ops from mindspore.common import dtype as mstype +from mindspore.ops import functional as F context.set_context(mode=context.GRAPH_MODE, device_target="CPU") @@ -94,7 +95,7 @@ def test_argmax_high_dims(): def adaptive_argmax_functional(nptype): x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype)) - output = ops.argmax(x, axis=-1) + output = F.argmax(x, dim=-1) expected = np.array([1, 0, 0]).astype(np.int32) np.testing.assert_array_almost_equal(output.asnumpy(), expected) @@ -225,12 +226,12 @@ def test_argmax_functional(): Expectation: the result match with expected result. 
""" x = Tensor([[1, 3, 2], [4, 6, 5], [7, 9, 8]], mstype.int32) - out_dim_none = ops.argmax(x, axis=None, keepdims=False) - out_dim_0 = ops.argmax(x, axis=0, keepdims=False) - out_dim_1 = ops.argmax(x, axis=1, keepdims=False) - out_dim_none_keepdim = ops.argmax(x, axis=None, keepdims=True) - out_dim_0_keepdim = ops.argmax(x, axis=0, keepdims=True) - out_dim_1_keepdim = ops.argmax(x, axis=1, keepdims=True) + out_dim_none = F.argmax(x, dim=None, keepdim=False) + out_dim_0 = F.argmax(x, dim=0, keepdim=False) + out_dim_1 = F.argmax(x, dim=1, keepdim=False) + out_dim_none_keepdim = F.argmax(x, dim=None, keepdim=True) + out_dim_0_keepdim = F.argmax(x, dim=0, keepdim=True) + out_dim_1_keepdim = F.argmax(x, dim=1, keepdim=True) assert out_dim_none.asnumpy() == 7 assert np.all(out_dim_0.asnumpy() == np.array([2, 2, 2])) diff --git a/tests/st/ops/cpu/test_cholesky_op.py b/tests/st/ops/cpu/test_cholesky_op.py new file mode 100644 index 00000000000..ee1a384c827 --- /dev/null +++ b/tests/st/ops/cpu/test_cholesky_op.py @@ -0,0 +1,45 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + + +class CholeskyFuncNet(nn.Cell): + def construct(self, x): + return F.cholesky(x, upper=False) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_functional_cholesky(mode): + """ + Feature: Test cholesky functional api. + Description: Test cholesky functional api for Graph and PyNative modes. + Expectation: the result match with expected result. 
+ """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32) + net = CholeskyFuncNet() + output = net(x) + expect_output = np.array([[1., 0.], [1., 1.]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/ops/cpu/test_div_op.py b/tests/st/ops/cpu/test_div_op.py index a548eb8510e..3b4b7a222be 100644 --- a/tests/st/ops/cpu/test_div_op.py +++ b/tests/st/ops/cpu/test_div_op.py @@ -94,7 +94,7 @@ def test_div_trunc_tensor_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = x.div(y, 'trunc') + output = x.div(y, rounding_mode='trunc') expected = np.array([[0., -0., -0., 0.], [1., 1., 0., -0.], [-0., 0., -1., -0.], @@ -113,7 +113,7 @@ def test_div_floor_tensor_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = x.div(y, 'floor') + output = x.div(y, rounding_mode='floor') expected = np.array([[0., -1., -1., 0.], [1., 1., 0., -1.], [-1., 0., -2., -1.], @@ -132,7 +132,7 @@ def test_div_functional_api(): [0.1062, 1.4581, 0.7759, -1.2344], [-0.1830, -0.0313, 1.1908, -1.4757]])) y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308])) - output = F.div(x, y) + output = F.div(x, y, rounding_mode=None) expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639], [0.2260, -3.4509, -1.2086, 6.8990], [0.1322, 4.9764, -0.9564, 5.3484], @@ -151,7 +151,7 @@ def test_div_trunc_functional_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = F.div(x, y, 'trunc') + output = F.div(x, y, rounding_mode='trunc') expected = np.array([[0., -0., -0., 0.], [1., 1., 0., -0.], [-0., 0., -1., -0.], @@ -170,7 +170,7 @@ def test_div_floor_functional_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = F.div(x, y, 'floor') + output = F.div(x, y, rounding_mode='floor') expected = np.array([[0., -1., -1., 0.], [1., 1., 0., -1.], [-1., 0., -2., -1.], diff --git a/tests/st/ops/cpu/test_unfold_op.py b/tests/st/ops/cpu/test_unfold_op.py new file mode 100644 index 00000000000..65af1557c95 --- /dev/null +++ b/tests/st/ops/cpu/test_unfold_op.py @@ -0,0 +1,46 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + + +class UnfoldFuncNet(nn.Cell): + def construct(self, x): + return F.unfold(x, kernel_size=3, dilation=1, stride=1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_unfold_functional_api_modes(mode): + """ + Feature: Test unfold functional api. + Description: Test unfold functional api for Graph and PyNative modes. + Expectation: the result match with expected result. + """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32) + net = UnfoldFuncNet() + output = net(x) + expected_shape = (4, 36, 30, 30) + assert output.dtype == x.dtype + assert output.shape == expected_shape diff --git a/tests/st/ops/gpu/test_argmax_op.py b/tests/st/ops/gpu/test_argmax_op.py index 04ae29922e2..8cc5d5b90a6 100644 --- a/tests/st/ops/gpu/test_argmax_op.py +++ b/tests/st/ops/gpu/test_argmax_op.py @@ -22,6 +22,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor, ops from mindspore.common import dtype as mstype +from mindspore.ops import functional as F class NetArgmax(nn.Cell): @@ -100,26 +101,27 @@ def test_argmax_high_dims(): assert (ms_output.asnumpy() == np_output).all() -def adaptive_argmax_functional(nptype): - x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype)) - output = ops.argmax(x, axis=-1) - expected = np.array([1, 0, 0]).astype(np.int32) - np.testing.assert_array_almost_equal(output.asnumpy(), expected) +class ArgmaxFuncNet(nn.Cell): + def construct(self, x): + return F.argmax(x, dim=-1) @pytest.mark.level1 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_argmax_float32_functional(): +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_functional_argmax(mode): """ - Feature: test argmax functional api. - Description: test float32 inputs. + Feature: Test argmax functional api. + Description: Test argmax functional api for Graph and PyNative modes. Expectation: the result match with expected result. """ - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - adaptive_argmax_functional(np.float32) - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - adaptive_argmax_functional(np.float32) + context.set_context(mode=mode, device_target="GPU") + x = Tensor([[1, 20, 5], [67, 8, 9], [130, 24, 15]], mstype.float32) + net = ArgmaxFuncNet() + output = net(x) + expect_output = np.array([1, 0, 0]).astype(np.int32) + assert np.allclose(output.asnumpy(), expect_output) @pytest.mark.level0 @@ -234,12 +236,12 @@ def test_argmax_functional(): Expectation: the result match with expected result. 
""" x = Tensor([[1, 3, 2], [4, 6, 5], [7, 9, 8]], mstype.int32) - out_dim_none = ops.argmax(x, axis=None, keepdims=False) - out_dim_0 = ops.argmax(x, axis=0, keepdims=False) - out_dim_1 = ops.argmax(x, axis=1, keepdims=False) - out_dim_none_keepdim = ops.argmax(x, axis=None, keepdims=True) - out_dim_0_keepdim = ops.argmax(x, axis=0, keepdims=True) - out_dim_1_keepdim = ops.argmax(x, axis=1, keepdims=True) + out_dim_none = F.argmax(x, dim=None, keepdim=False) + out_dim_0 = F.argmax(x, dim=0, keepdim=False) + out_dim_1 = F.argmax(x, dim=1, keepdim=False) + out_dim_none_keepdim = F.argmax(x, dim=None, keepdim=True) + out_dim_0_keepdim = F.argmax(x, dim=0, keepdim=True) + out_dim_1_keepdim = F.argmax(x, dim=1, keepdim=True) assert out_dim_none.asnumpy() == 7 assert np.all(out_dim_0.asnumpy() == np.array([2, 2, 2])) diff --git a/tests/st/ops/gpu/test_div_op.py b/tests/st/ops/gpu/test_div_op.py index f5ba2e12d11..e7f0517e7c4 100644 --- a/tests/st/ops/gpu/test_div_op.py +++ b/tests/st/ops/gpu/test_div_op.py @@ -159,7 +159,7 @@ def test_div_trunc_tensor_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = x.div(y, 'trunc') + output = x.div(y, rounding_mode='trunc') expected = np.array([[0., -0., -0., 0.], [1., 1., 0., -0.], [-0., 0., -1., -0.], @@ -178,7 +178,7 @@ def test_div_floor_tensor_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = x.div(y, 'floor') + output = x.div(y, rounding_mode='floor') expected = np.array([[0., -1., -1., 0.], [1., 1., 0., -1.], [-1., 0., -2., -1.], @@ -197,7 +197,7 @@ def test_div_functional_api(): [0.1062, 1.4581, 0.7759, -1.2344], [-0.1830, -0.0313, 1.1908, -1.4757]])) y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308])) - output = F.div(x, y) + output = F.div(x, y, rounding_mode=None) expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639], [0.2260, -3.4509, -1.2086, 6.8990], [0.1322, 4.9764, -0.9564, 5.3484], @@ -216,7 +216,7 @@ def test_div_trunc_functional_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = F.div(x, y, 'trunc') + output = F.div(x, y, rounding_mode='trunc') expected = np.array([[0., -0., -0., 0.], [1., 1., 0., -0.], [-0., 0., -1., -0.], @@ -235,7 +235,7 @@ def test_div_floor_functional_api(): [-0.2601, -0.2397, 0.5832, 0.2250], [0.0322, 0.7103, 0.6315, -0.8621]])) y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389])) - output = F.div(x, y, 'floor') + output = F.div(x, y, rounding_mode='floor') expected = np.array([[0., -1., -1., 0.], [1., 1., 0., -1.], [-1., 0., -2., -1.], diff --git a/tests/st/ops/gpu/test_unfold_op.py b/tests/st/ops/gpu/test_unfold_op.py new file mode 100644 index 00000000000..db3024efa28 --- /dev/null +++ b/tests/st/ops/gpu/test_unfold_op.py @@ -0,0 +1,46 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + + +class UnfoldFuncNet(nn.Cell): + def construct(self, x): + return F.unfold(x, kernel_size=3, dilation=1, stride=1) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_unfold_functional_api_modes(mode): + """ + Feature: Test unfold functional api. + Description: Test unfold functional api for Graph and PyNative modes. + Expectation: the result match with expected result. + """ + context.set_context(mode=mode, device_target="GPU") + x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32) + net = UnfoldFuncNet() + output = net(x) + expected_shape = (4, 36, 30, 30) + assert output.dtype == x.dtype + assert output.shape == expected_shape diff --git a/tests/st/tensor/test_cholesky.py b/tests/st/tensor/test_cholesky.py new file mode 100644 index 00000000000..214b90ec2f1 --- /dev/null +++ b/tests/st/tensor/test_cholesky.py @@ -0,0 +1,47 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +from mindspore.common import dtype as mstype + + +class CholeskyTensorNet(nn.Cell): + def construct(self, x): + return x.cholesky(upper=False) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_tensor_cholesky(mode): + """ + Feature: Test cholesky tensor api. + Description: Test cholesky tensor api for Graph and PyNative modes. + Expectation: the result match with expected result. + """ + context.set_context(mode=mode) + x = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32) + net = CholeskyTensorNet() + output = net(x) + expect_output = np.array([[1., 0.], [1., 1.]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_cholesky_inverse.py b/tests/st/tensor/test_cholesky_inverse.py new file mode 100644 index 00000000000..818b84347cb --- /dev/null +++ b/tests/st/tensor/test_cholesky_inverse.py @@ -0,0 +1,47 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +from mindspore.common import dtype as mstype + + +class CholeskyInverseTensorNet(nn.Cell): + def construct(self, x): + return x.cholesky_inverse() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_tensor_cholesky_inverse(mode): + """ + Feature: Test cholesky_inverse tensor api. + Description: Test cholesky_inverse tensor api for Graph and PyNative modes. + Expectation: the result match with expected result. + """ + context.set_context(mode=mode) + x = Tensor([[2, 0, 0], [4, 1, 0], [-1, 1, 2]], mstype.float32) + net = CholeskyInverseTensorNet() + output = net(x) + expect_output = np.array([[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]]) + assert np.allclose(output.asnumpy(), expect_output) diff --git a/tests/st/tensor/test_unfold.py b/tests/st/tensor/test_unfold.py new file mode 100644 index 00000000000..e00add31017 --- /dev/null +++ b/tests/st/tensor/test_unfold.py @@ -0,0 +1,47 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +from mindspore.common import dtype as mstype + + +class UnfoldTensorNet(nn.Cell): + def construct(self, x): + return x.unfold(kernel_size=3, dilation=1, stride=1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_tensor_unfold(mode): + """ + Feature: Test unfold tensor api. + Description: Test unfold tensor api for Graph and PyNative modes. + Expectation: the result match with expected result. + """ + context.set_context(mode=mode) + x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32) + net = UnfoldTensorNet() + output = net(x) + expected_shape = (4, 36, 30, 30) + assert output.dtype == x.dtype + assert output.shape == expected_shape
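Editor's note (not part of the patch): the hunks above rename the ops.argmax parameters to dim/keepdim and make rounding_mode keyword-only for ops.div, Tensor.div and Tensor.divide. The snippet below is a minimal usage sketch of those updated signatures, mirroring the inputs used in the docstring example and the system tests; it assumes a MindSpore build that already contains these changes.

import numpy as np
from mindspore import Tensor, ops
from mindspore.common import dtype as mstype

x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]), mstype.float32)

# argmax now takes dim/keepdim (previously axis/keepdims).
print(ops.argmax(x, dim=-1))                  # [1 0 0]
print(ops.argmax(x, dim=None))                # 6, the flat index of 130

# rounding_mode is keyword-only; it can no longer be passed positionally.
y = Tensor(np.array([2.0, 3.0, 4.0]), mstype.float32)
print(ops.div(x, y, rounding_mode=None))      # true division
print(x.div(y, rounding_mode='floor'))        # floor division
print(x.divide(y, rounding_mode='trunc'))     # divide() forwards the keyword

Making rounding_mode keyword-only follows the usual convention for mode flags and keeps a mode string from being mistaken for a positional operand.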
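Likewise, a short sketch (again not part of the patch) of the tensor methods exercised by the new STs, reusing their inputs and expected values:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

# Tensor.unfold extracts sliding 3x3 blocks; the STs check only dtype and shape.
images = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
patches = images.unfold(kernel_size=3, dilation=1, stride=1)
print(patches.shape)                          # (4, 36, 30, 30), as asserted above

# Tensor.cholesky returns the lower-triangular factor L with x = L @ L.T.
a = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32)
print(a.cholesky(upper=False))                # [[1. 0.] [1. 1.]]

# Tensor.cholesky_inverse takes that factor and returns (L @ L.T)^-1.
factor = Tensor([[2.0, 0.0, 0.0], [4.0, 1.0, 0.0], [-1.0, 1.0, 2.0]], mstype.float32)
print(factor.cholesky_inverse())              # [[5.8125 -2.625 0.625] ...]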