!48317 [ST][MS][OPS] Fix operation validation issues.

Merge pull request !48317 from alashkari/val-updates
This commit is contained in:
i-robot 2023-03-02 06:35:22 +00:00 committed by Gitee
commit 1c4e24f1b1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
19 changed files with 369 additions and 77 deletions

View File

@ -393,6 +393,7 @@ Reduction函数
mindspore.ops.inverse
mindspore.ops.ger
mindspore.ops.kron
mindspore.ops.lstsq
mindspore.ops.matmul
mindspore.ops.matrix_solve
mindspore.ops.matrix_exp

View File

@ -186,6 +186,7 @@ mindspore.Tensor
mindspore.Tensor.lt
mindspore.Tensor.masked_fill
mindspore.Tensor.masked_select
mindspore.Tensor.matmul
mindspore.Tensor.matrix_power
mindspore.Tensor.max
mindspore.Tensor.maximum
@ -201,6 +202,7 @@ mindspore.Tensor
mindspore.Tensor.mT
mindspore.Tensor.mul
mindspore.Tensor.multiply
mindspore.Tensor.mvlgamma
mindspore.Tensor.nan_to_num
mindspore.Tensor.nansum
mindspore.Tensor.narrow
@ -208,6 +210,7 @@ mindspore.Tensor
mindspore.Tensor.ndim
mindspore.Tensor.ndimension
mindspore.Tensor.ne
mindspore.Tensor.neg
mindspore.Tensor.negative
mindspore.Tensor.nelement
mindspore.Tensor.new_ones

View File

@ -192,6 +192,7 @@
mindspore.Tensor.lt
mindspore.Tensor.masked_fill
mindspore.Tensor.masked_select
mindspore.Tensor.matmul
mindspore.Tensor.matrix_power
mindspore.Tensor.max
mindspore.Tensor.maximum
@ -207,6 +208,7 @@
mindspore.Tensor.mT
mindspore.Tensor.mul
mindspore.Tensor.multiply
mindspore.Tensor.mvlgamma
mindspore.Tensor.nan_to_num
mindspore.Tensor.nansum
mindspore.Tensor.narrow
@ -214,6 +216,7 @@
mindspore.Tensor.ndim
mindspore.Tensor.ndimension
mindspore.Tensor.ne
mindspore.Tensor.neg
mindspore.Tensor.negative
mindspore.Tensor.nelement
mindspore.Tensor.new_ones

View File

@ -393,6 +393,7 @@ Linear Algebraic Functions
mindspore.ops.inverse
mindspore.ops.ger
mindspore.ops.kron
mindspore.ops.lstsq
mindspore.ops.matmul
mindspore.ops.matrix_solve
mindspore.ops.matrix_exp

View File

@ -3365,7 +3365,7 @@ def true_divide(divident, divisor):
r"""
Computes the element-wise division of input tensors.
"""
return F.div(divident, divisor, None)
return F.div(divident, divisor, rounding_mode=None)
# pylint: disable=redefined-outer-name
@ -4012,11 +4012,11 @@ def multiply(input, other):
return F.multiply(input, other)
def div(input, other, rounding_mode=None):
def div(input, value, *, rounding_mode=None):
r"""
Divides the tensor `input` by the given input tensor `other` in floating-point type element-wise.
Divides the tensor `input` by the given input tensor `value` in floating-point type element-wise.
"""
return F.div(input, other, rounding_mode)
return F.div(input, value, rounding_mode=rounding_mode)
def equal(x, y):

View File

@ -827,7 +827,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
For details, please refer to :func:`mindspore.ops.div`.
"""
self._init_check()
return tensor_operator_registry.get('div')(self, value, None)
return tensor_operator_registry.get('div')(self, value, rounding_mode=None)
def triu(self, diagonal=0):
r"""
@ -3670,19 +3670,19 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
self._init_check()
return tensor_operator_registry.get('multiply')(self, value)
def div(self, other, rounding_mode=None):
def div(self, value, *, rounding_mode=None):
r"""
For details, please refer to :func:`mindspore.ops.div`.
"""
self._init_check()
return tensor_operator_registry.get('div')(self, other, rounding_mode)
return tensor_operator_registry.get('div')(self, value, rounding_mode=rounding_mode)
def divide(self, value, *, rounding_mode=None):
r"""
Alias for :func:`mindspore.Tensor.div`.
"""
self._init_check()
return tensor_operator_registry.get('div')(self, value, rounding_mode)
return tensor_operator_registry.get('div')(self, value, rounding_mode=rounding_mode)
def equal(self, other):
r"""

View File

@ -49,7 +49,8 @@ from mindspore.ops.operations.array_ops import (
Lstsq,
Mvlgamma,
CountNonZero,
Tril
Tril,
Argmax
)
from mindspore.ops.operations.array_ops import TensorScatterElements
from mindspore.common import Tensor
@ -5477,43 +5478,42 @@ def max(x, axis=0, keep_dims=False):
return argmax_with_value_op(x)
def argmax(x, axis=None, keepdims=False):
def argmax(input, dim=None, keepdim=False):
"""
Return the indices of the maximum values of a tensor across a dimension.
Args:
x (Tensor): Input tensor.
axis (Union[int, None], optional): The dimension to reduce. If `axis` is None,
the indices of the maximum value within the flattened input will be returned.
Default: None.
keepdims (bool, optional): Whether the output tensor retains the specified
dimension. Ignored if `axis` is None. Default: False.
input (Tensor): Input tensor.
dim (Union[int, None]): The dimension to reduce. If `dim` is None, the indices of the maximum
value within the flattened input will be returned. Default: None.
keepdim (bool): Whether the output tensor retains the specified
dimension. Ignored if `dim` is None. Default: False.
Returns:
Tensor, indices of the maximum values across a dimension.
Raises:
ValueError: If `axis` is out of range.
ValueError: If `dim` is out of range.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(np.float32))
>>> output = ops.argmax(x, axis=-1)
>>> output = ops.argmax(x, dim=-1)
>>> print(output)
[1 0 0]
"""
if x.shape == ():
if input.shape == ():
return Tensor(0)
is_axis_none = False
if axis is None:
x = reshape_(x, (-1,))
axis = 0
is_axis_none = True
out = P.Argmax(axis, mstype.int64)(x)
if keepdims and not is_axis_none:
out = expand_dims_(out, axis)
is_dim_none = False
if dim is None:
input = reshape_(input, (-1,))
dim = 0
is_dim_none = True
out = _get_cache_prim(Argmax)(dim, mstype.int64)(input)
if keepdim and not is_dim_none:
out = expand_dims_(out, dim)
return out

View File

@ -883,7 +883,7 @@ def multiply(input, other):
return tensor_mul(input, other)
def div(input, other, rounding_mode=None):
def div(input, other, *, rounding_mode=None):
"""
Divides the first input tensor by the second input tensor in floating-point type element-wise.
@ -948,7 +948,7 @@ def divide(x, other, *, rounding_mode=None):
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
return div(x, other, rounding_mode)
return div(x, other, rounding_mode=rounding_mode)
def float_power(x, exponent):

View File

@ -17,9 +17,11 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore import Tensor
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
context.set_context(device_target="Ascend")
@ -42,24 +44,25 @@ def test_net():
print(output.asnumpy())
def adaptive_argmax_functional(nptype):
x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype))
output = ops.argmax(x, axis=-1)
expected = np.array([1, 0, 0]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
class ArgmaxFuncNet(nn.Cell):
def construct(self, x):
return F.argmax(x, dim=-1)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_argmax_float32_functional():
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_functional_argmax(mode):
"""
Feature: test argmax functional api.
Description: test float32 inputs.
Feature: Test argmax functional api.
Description: Test argmax functional api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
adaptive_argmax_functional(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
adaptive_argmax_functional(np.float32)
context.set_context(mode=mode, device_target="Ascend")
x = Tensor([[1, 20, 5], [67, 8, 9], [130, 24, 15]], mstype.float32)
net = ArgmaxFuncNet()
output = net(x)
expect_output = np.array([1, 0, 0]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -22,6 +22,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -94,7 +95,7 @@ def test_argmax_high_dims():
def adaptive_argmax_functional(nptype):
x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype))
output = ops.argmax(x, axis=-1)
output = F.argmax(x, dim=-1)
expected = np.array([1, 0, 0]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@ -225,12 +226,12 @@ def test_argmax_functional():
Expectation: the result match with expected result.
"""
x = Tensor([[1, 3, 2], [4, 6, 5], [7, 9, 8]], mstype.int32)
out_dim_none = ops.argmax(x, axis=None, keepdims=False)
out_dim_0 = ops.argmax(x, axis=0, keepdims=False)
out_dim_1 = ops.argmax(x, axis=1, keepdims=False)
out_dim_none_keepdim = ops.argmax(x, axis=None, keepdims=True)
out_dim_0_keepdim = ops.argmax(x, axis=0, keepdims=True)
out_dim_1_keepdim = ops.argmax(x, axis=1, keepdims=True)
out_dim_none = F.argmax(x, dim=None, keepdim=False)
out_dim_0 = F.argmax(x, dim=0, keepdim=False)
out_dim_1 = F.argmax(x, dim=1, keepdim=False)
out_dim_none_keepdim = F.argmax(x, dim=None, keepdim=True)
out_dim_0_keepdim = F.argmax(x, dim=0, keepdim=True)
out_dim_1_keepdim = F.argmax(x, dim=1, keepdim=True)
assert out_dim_none.asnumpy() == 7
assert np.all(out_dim_0.asnumpy() == np.array([2, 2, 2]))

View File

@ -0,0 +1,45 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
class CholeskyFuncNet(nn.Cell):
def construct(self, x):
return F.cholesky(x, upper=False)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_functional_cholesky(mode):
"""
Feature: Test cholesky functional api.
Description: Test cholesky functional api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32)
net = CholeskyFuncNet()
output = net(x)
expect_output = np.array([[1., 0.], [1., 1.]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -94,7 +94,7 @@ def test_div_trunc_tensor_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = x.div(y, 'trunc')
output = x.div(y, rounding_mode='trunc')
expected = np.array([[0., -0., -0., 0.],
[1., 1., 0., -0.],
[-0., 0., -1., -0.],
@ -113,7 +113,7 @@ def test_div_floor_tensor_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = x.div(y, 'floor')
output = x.div(y, rounding_mode='floor')
expected = np.array([[0., -1., -1., 0.],
[1., 1., 0., -1.],
[-1., 0., -2., -1.],
@ -132,7 +132,7 @@ def test_div_functional_api():
[0.1062, 1.4581, 0.7759, -1.2344],
[-0.1830, -0.0313, 1.1908, -1.4757]]))
y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308]))
output = F.div(x, y)
output = F.div(x, y, rounding_mode=None)
expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639],
[0.2260, -3.4509, -1.2086, 6.8990],
[0.1322, 4.9764, -0.9564, 5.3484],
@ -151,7 +151,7 @@ def test_div_trunc_functional_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = F.div(x, y, 'trunc')
output = F.div(x, y, rounding_mode='trunc')
expected = np.array([[0., -0., -0., 0.],
[1., 1., 0., -0.],
[-0., 0., -1., -0.],
@ -170,7 +170,7 @@ def test_div_floor_functional_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = F.div(x, y, 'floor')
output = F.div(x, y, rounding_mode='floor')
expected = np.array([[0., -1., -1., 0.],
[1., 1., 0., -1.],
[-1., 0., -2., -1.],

View File

@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
class UnfoldFuncNet(nn.Cell):
def construct(self, x):
return F.unfold(x, kernel_size=3, dilation=1, stride=1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_unfold_functional_api_modes(mode):
"""
Feature: Test unfold functional api.
Description: Test unfold functional api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
net = UnfoldFuncNet()
output = net(x)
expected_shape = (4, 36, 30, 30)
assert output.dtype == x.dtype
assert output.shape == expected_shape

View File

@ -22,6 +22,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F
class NetArgmax(nn.Cell):
@ -100,26 +101,27 @@ def test_argmax_high_dims():
assert (ms_output.asnumpy() == np_output).all()
def adaptive_argmax_functional(nptype):
x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype))
output = ops.argmax(x, axis=-1)
expected = np.array([1, 0, 0]).astype(np.int32)
np.testing.assert_array_almost_equal(output.asnumpy(), expected)
class ArgmaxFuncNet(nn.Cell):
def construct(self, x):
return F.argmax(x, dim=-1)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmax_float32_functional():
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_functional_argmax(mode):
"""
Feature: test argmax functional api.
Description: test float32 inputs.
Feature: Test argmax functional api.
Description: Test argmax functional api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
adaptive_argmax_functional(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
adaptive_argmax_functional(np.float32)
context.set_context(mode=mode, device_target="GPU")
x = Tensor([[1, 20, 5], [67, 8, 9], [130, 24, 15]], mstype.float32)
net = ArgmaxFuncNet()
output = net(x)
expect_output = np.array([1, 0, 0]).astype(np.int32)
assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@ -234,12 +236,12 @@ def test_argmax_functional():
Expectation: the result match with expected result.
"""
x = Tensor([[1, 3, 2], [4, 6, 5], [7, 9, 8]], mstype.int32)
out_dim_none = ops.argmax(x, axis=None, keepdims=False)
out_dim_0 = ops.argmax(x, axis=0, keepdims=False)
out_dim_1 = ops.argmax(x, axis=1, keepdims=False)
out_dim_none_keepdim = ops.argmax(x, axis=None, keepdims=True)
out_dim_0_keepdim = ops.argmax(x, axis=0, keepdims=True)
out_dim_1_keepdim = ops.argmax(x, axis=1, keepdims=True)
out_dim_none = F.argmax(x, dim=None, keepdim=False)
out_dim_0 = F.argmax(x, dim=0, keepdim=False)
out_dim_1 = F.argmax(x, dim=1, keepdim=False)
out_dim_none_keepdim = F.argmax(x, dim=None, keepdim=True)
out_dim_0_keepdim = F.argmax(x, dim=0, keepdim=True)
out_dim_1_keepdim = F.argmax(x, dim=1, keepdim=True)
assert out_dim_none.asnumpy() == 7
assert np.all(out_dim_0.asnumpy() == np.array([2, 2, 2]))

View File

@ -159,7 +159,7 @@ def test_div_trunc_tensor_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = x.div(y, 'trunc')
output = x.div(y, rounding_mode='trunc')
expected = np.array([[0., -0., -0., 0.],
[1., 1., 0., -0.],
[-0., 0., -1., -0.],
@ -178,7 +178,7 @@ def test_div_floor_tensor_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = x.div(y, 'floor')
output = x.div(y, rounding_mode='floor')
expected = np.array([[0., -1., -1., 0.],
[1., 1., 0., -1.],
[-1., 0., -2., -1.],
@ -197,7 +197,7 @@ def test_div_functional_api():
[0.1062, 1.4581, 0.7759, -1.2344],
[-0.1830, -0.0313, 1.1908, -1.4757]]))
y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308]))
output = F.div(x, y)
output = F.div(x, y, rounding_mode=None)
expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639],
[0.2260, -3.4509, -1.2086, 6.8990],
[0.1322, 4.9764, -0.9564, 5.3484],
@ -216,7 +216,7 @@ def test_div_trunc_functional_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = F.div(x, y, 'trunc')
output = F.div(x, y, rounding_mode='trunc')
expected = np.array([[0., -0., -0., 0.],
[1., 1., 0., -0.],
[-0., 0., -1., -0.],
@ -235,7 +235,7 @@ def test_div_floor_functional_api():
[-0.2601, -0.2397, 0.5832, 0.2250],
[0.0322, 0.7103, 0.6315, -0.8621]]))
y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
output = F.div(x, y, 'floor')
output = F.div(x, y, rounding_mode='floor')
expected = np.array([[0., -1., -1., 0.],
[1., 1., 0., -1.],
[-1., 0., -2., -1.],

View File

@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
class UnfoldFuncNet(nn.Cell):
def construct(self, x):
return F.unfold(x, kernel_size=3, dilation=1, stride=1)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_unfold_functional_api_modes(mode):
"""
Feature: Test unfold functional api.
Description: Test unfold functional api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=mode, device_target="CPU")
x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
net = UnfoldFuncNet()
output = net(x)
expected_shape = (4, 36, 30, 30)
assert output.dtype == x.dtype
assert output.shape == expected_shape

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.common import dtype as mstype
class CholeskyTensorNet(nn.Cell):
def construct(self, x):
return x.cholesky(upper=False)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_tensor_cholesky(mode):
"""
Feature: Test cholesky tensor api.
Description: Test cholesky tensor api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=mode)
x = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32)
net = CholeskyTensorNet()
output = net(x)
expect_output = np.array([[1., 0.], [1., 1.]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.common import dtype as mstype
class CholeskyInverseTensorNet(nn.Cell):
def construct(self, x):
return x.cholesky_inverse()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_tensor_cholesky_inverse(mode):
"""
Feature: Test cholesky_inverse tensor api.
Description: Test cholesky_inverse tensor api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=mode)
x = Tensor([[2, 0, 0], [4, 1, 0], [-1, 1, 2]], mstype.float32)
net = CholeskyInverseTensorNet()
output = net(x)
expect_output = np.array([[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]])
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.common import dtype as mstype
class UnfoldTensorNet(nn.Cell):
def construct(self, x):
return x.unfold(kernel_size=3, dilation=1, stride=1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_tensor_unfold(mode):
"""
Feature: Test unfold tensor api.
Description: Test unfold tensor api for Graph and PyNative modes.
Expectation: the result match with expected result.
"""
context.set_context(mode=mode)
x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
net = UnfoldTensorNet()
output = net(x)
expected_shape = (4, 36, 30, 30)
assert output.dtype == x.dtype
assert output.shape == expected_shape