forked from mindspore-Ecosystem/mindspore

commit 1c4e24f1b1

!48317 [ST][MS][OPS] Fix operation validation issues.

Merge pull request !48317 from alashkari/val-updates
@@ -393,6 +393,7 @@ Reduction Functions

    mindspore.ops.inverse
    mindspore.ops.ger
    mindspore.ops.kron
    mindspore.ops.lstsq
    mindspore.ops.matmul
    mindspore.ops.matrix_solve
    mindspore.ops.matrix_exp

@@ -186,6 +186,7 @@ mindspore.Tensor

    mindspore.Tensor.lt
    mindspore.Tensor.masked_fill
    mindspore.Tensor.masked_select
    mindspore.Tensor.matmul
    mindspore.Tensor.matrix_power
    mindspore.Tensor.max
    mindspore.Tensor.maximum

@@ -201,6 +202,7 @@ mindspore.Tensor

    mindspore.Tensor.mT
    mindspore.Tensor.mul
    mindspore.Tensor.multiply
    mindspore.Tensor.mvlgamma
    mindspore.Tensor.nan_to_num
    mindspore.Tensor.nansum
    mindspore.Tensor.narrow

@@ -208,6 +210,7 @@ mindspore.Tensor

    mindspore.Tensor.ndim
    mindspore.Tensor.ndimension
    mindspore.Tensor.ne
    mindspore.Tensor.neg
    mindspore.Tensor.negative
    mindspore.Tensor.nelement
    mindspore.Tensor.new_ones

@@ -192,6 +192,7 @@

    mindspore.Tensor.lt
    mindspore.Tensor.masked_fill
    mindspore.Tensor.masked_select
    mindspore.Tensor.matmul
    mindspore.Tensor.matrix_power
    mindspore.Tensor.max
    mindspore.Tensor.maximum

@@ -207,6 +208,7 @@

    mindspore.Tensor.mT
    mindspore.Tensor.mul
    mindspore.Tensor.multiply
    mindspore.Tensor.mvlgamma
    mindspore.Tensor.nan_to_num
    mindspore.Tensor.nansum
    mindspore.Tensor.narrow

@@ -214,6 +216,7 @@

    mindspore.Tensor.ndim
    mindspore.Tensor.ndimension
    mindspore.Tensor.ne
    mindspore.Tensor.neg
    mindspore.Tensor.negative
    mindspore.Tensor.nelement
    mindspore.Tensor.new_ones

@@ -393,6 +393,7 @@ Linear Algebraic Functions

    mindspore.ops.inverse
    mindspore.ops.ger
    mindspore.ops.kron
    mindspore.ops.lstsq
    mindspore.ops.matmul
    mindspore.ops.matrix_solve
    mindspore.ops.matrix_exp
@@ -3365,7 +3365,7 @@ def true_divide(divident, divisor):
     r"""
     Computes the element-wise division of input tensors.
     """
-    return F.div(divident, divisor, None)
+    return F.div(divident, divisor, rounding_mode=None)
 
 
 # pylint: disable=redefined-outer-name
@@ -4012,11 +4012,11 @@ def multiply(input, other):
     return F.multiply(input, other)
 
 
-def div(input, other, rounding_mode=None):
+def div(input, value, *, rounding_mode=None):
     r"""
-    Divides the tensor `input` by the given input tensor `other` in floating-point type element-wise.
+    Divides the tensor `input` by the given input tensor `value` in floating-point type element-wise.
     """
-    return F.div(input, other, rounding_mode)
+    return F.div(input, value, rounding_mode=rounding_mode)
 
 
 def equal(x, y):
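With `rounding_mode` made keyword-only, positional calls such as `F.div(x, y, 'trunc')` stop working. A minimal sketch of the new calling convention (illustrative values, not taken from the patch; 'trunc' rounds toward zero, 'floor' toward negative infinity):

    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.array([7.0, -7.0]))
    y = Tensor(np.array([2.0, 2.0]))

    print(ops.div(x, y))                         # [ 3.5 -3.5]
    print(ops.div(x, y, rounding_mode='trunc'))  # [ 3. -3.]  rounds toward zero
    print(ops.div(x, y, rounding_mode='floor'))  # [ 3. -4.]  rounds toward -inf
    # ops.div(x, y, 'trunc') now raises TypeError: rounding_mode is keyword-only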
@@ -827,7 +827,7 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         For details, please refer to :func:`mindspore.ops.div`.
         """
         self._init_check()
-        return tensor_operator_registry.get('div')(self, value, None)
+        return tensor_operator_registry.get('div')(self, value, rounding_mode=None)
 
     def triu(self, diagonal=0):
         r"""
@@ -3670,19 +3670,19 @@ class Tensor(Tensor_, metaclass=_TensorMeta):
         self._init_check()
         return tensor_operator_registry.get('multiply')(self, value)
 
-    def div(self, other, rounding_mode=None):
+    def div(self, value, *, rounding_mode=None):
         r"""
         For details, please refer to :func:`mindspore.ops.div`.
         """
         self._init_check()
-        return tensor_operator_registry.get('div')(self, other, rounding_mode)
+        return tensor_operator_registry.get('div')(self, value, rounding_mode=rounding_mode)
 
     def divide(self, value, *, rounding_mode=None):
         r"""
         Alias for :func:`mindspore.Tensor.div`.
         """
         self._init_check()
-        return tensor_operator_registry.get('div')(self, value, rounding_mode)
+        return tensor_operator_registry.get('div')(self, value, rounding_mode=rounding_mode)
 
     def equal(self, other):
         r"""
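The Tensor method and its `divide` alias follow the same keyword-only convention; a minimal sketch (same illustrative values as the functional example above):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.array([7.0, -7.0]))
    y = Tensor(np.array([2.0, 2.0]))

    print(x.div(y, rounding_mode='floor'))     # [ 3. -4.]
    print(x.divide(y, rounding_mode='floor'))  # alias of Tensor.div, same result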
@@ -49,7 +49,8 @@ from mindspore.ops.operations.array_ops import (
     Lstsq,
     Mvlgamma,
     CountNonZero,
-    Tril
+    Tril,
+    Argmax
 )
 from mindspore.ops.operations.array_ops import TensorScatterElements
 from mindspore.common import Tensor
@@ -5477,43 +5478,42 @@ def max(x, axis=0, keep_dims=False):
     return argmax_with_value_op(x)
 
 
-def argmax(x, axis=None, keepdims=False):
+def argmax(input, dim=None, keepdim=False):
     """
     Return the indices of the maximum values of a tensor across a dimension.
 
     Args:
-        x (Tensor): Input tensor.
-        axis (Union[int, None], optional): The dimension to reduce. If `axis` is None,
-            the indices of the maximum value within the flattened input will be returned.
-            Default: None.
-        keepdims (bool, optional): Whether the output tensor retains the specified
-            dimension. Ignored if `axis` is None. Default: False.
+        input (Tensor): Input tensor.
+        dim (Union[int, None]): The dimension to reduce. If `dim` is None, the indices of the maximum
+            value within the flattened input will be returned. Default: None.
+        keepdim (bool): Whether the output tensor retains the specified
+            dimension. Ignored if `dim` is None. Default: False.
 
     Returns:
         Tensor, indices of the maximum values across a dimension.
 
     Raises:
-        ValueError: If `axis` is out of range.
+        ValueError: If `dim` is out of range.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(np.float32))
-        >>> output = ops.argmax(x, axis=-1)
+        >>> output = ops.argmax(x, dim=-1)
         >>> print(output)
         [1 0 0]
     """
-    if x.shape == ():
+    if input.shape == ():
         return Tensor(0)
-    is_axis_none = False
-    if axis is None:
-        x = reshape_(x, (-1,))
-        axis = 0
-        is_axis_none = True
-    out = P.Argmax(axis, mstype.int64)(x)
-    if keepdims and not is_axis_none:
-        out = expand_dims_(out, axis)
+    is_dim_none = False
+    if dim is None:
+        input = reshape_(input, (-1,))
+        dim = 0
+        is_dim_none = True
+    out = _get_cache_prim(Argmax)(dim, mstype.int64)(input)
+    if keepdim and not is_dim_none:
+        out = expand_dims_(out, dim)
     return out
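Besides the PyTorch-style `dim`/`keepdim` spelling, the body now fetches a cached `Argmax` primitive through `_get_cache_prim` instead of constructing `P.Argmax` on every call. A short sketch of the renamed interface, reusing the values from the updated tests:

    import numpy as np
    from mindspore import Tensor, ops

    x = Tensor(np.array([[1, 3, 2], [4, 6, 5], [7, 9, 8]]).astype(np.float32))

    print(ops.argmax(x, dim=None))                   # 7: index of 9 in the flattened input
    print(ops.argmax(x, dim=0))                      # [2 2 2]: row index of each column max
    print(ops.argmax(x, dim=1, keepdim=True).shape)  # (3, 1): reduced axis kept as size 1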
@@ -883,7 +883,7 @@ def multiply(input, other):
     return tensor_mul(input, other)
 
 
-def div(input, other, rounding_mode=None):
+def div(input, other, *, rounding_mode=None):
     """
     Divides the first input tensor by the second input tensor in floating-point type element-wise.
@@ -948,7 +948,7 @@ def divide(x, other, *, rounding_mode=None):
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
     """
-    return div(x, other, rounding_mode)
+    return div(x, other, rounding_mode=rounding_mode)
 
 
 def float_power(x, exponent):
@@ -17,9 +17,11 @@ import pytest
 
 import mindspore.context as context
 import mindspore.nn as nn
-from mindspore import Tensor, ops
+from mindspore import Tensor
 from mindspore.common.api import jit
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
 
 context.set_context(device_target="Ascend")
@@ -42,24 +44,25 @@ def test_net():
     print(output.asnumpy())
 
 
-def adaptive_argmax_functional(nptype):
-    x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype))
-    output = ops.argmax(x, axis=-1)
-    expected = np.array([1, 0, 0]).astype(np.int32)
-    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+class ArgmaxFuncNet(nn.Cell):
+    def construct(self, x):
+        return F.argmax(x, dim=-1)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
-def test_argmax_float32_functional():
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_functional_argmax(mode):
     """
-    Feature: test argmax functional api.
-    Description: test float32 inputs.
+    Feature: Test argmax functional api.
+    Description: Test argmax functional api for Graph and PyNative modes.
     Expectation: the result match with expected result.
     """
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    adaptive_argmax_functional(np.float32)
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    adaptive_argmax_functional(np.float32)
+    context.set_context(mode=mode, device_target="Ascend")
+    x = Tensor([[1, 20, 5], [67, 8, 9], [130, 24, 15]], mstype.float32)
+    net = ArgmaxFuncNet()
+    output = net(x)
+    expect_output = np.array([1, 0, 0]).astype(np.int32)
+    assert np.allclose(output.asnumpy(), expect_output)
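The rewritten tests trade the helper that toggled contexts by hand for an `nn.Cell` plus a parametrized mode: GRAPH_MODE compiles `construct`, PYNATIVE_MODE runs it eagerly, so one test body covers both. A generic sketch of the pattern (`OpNet` and `test_op_both_modes` are hypothetical names, not from the patch):

    import pytest
    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor


    class OpNet(nn.Cell):               # graph mode compiles construct()
        def construct(self, x):
            return x * 2


    @pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
    def test_op_both_modes(mode):
        context.set_context(mode=mode)  # same body runs compiled, then eagerly
        output = OpNet()(Tensor(np.array([1, 2])))
        assert (output.asnumpy() == np.array([2, 4])).all()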
@@ -22,6 +22,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor, ops
 from mindspore.common import dtype as mstype
+from mindspore.ops import functional as F
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@@ -94,7 +95,7 @@ def test_argmax_high_dims():
 
 def adaptive_argmax_functional(nptype):
     x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype))
-    output = ops.argmax(x, axis=-1)
+    output = F.argmax(x, dim=-1)
     expected = np.array([1, 0, 0]).astype(np.int32)
     np.testing.assert_array_almost_equal(output.asnumpy(), expected)
@@ -225,12 +226,12 @@ def test_argmax_functional():
     Expectation: the result match with expected result.
     """
     x = Tensor([[1, 3, 2], [4, 6, 5], [7, 9, 8]], mstype.int32)
-    out_dim_none = ops.argmax(x, axis=None, keepdims=False)
-    out_dim_0 = ops.argmax(x, axis=0, keepdims=False)
-    out_dim_1 = ops.argmax(x, axis=1, keepdims=False)
-    out_dim_none_keepdim = ops.argmax(x, axis=None, keepdims=True)
-    out_dim_0_keepdim = ops.argmax(x, axis=0, keepdims=True)
-    out_dim_1_keepdim = ops.argmax(x, axis=1, keepdims=True)
+    out_dim_none = F.argmax(x, dim=None, keepdim=False)
+    out_dim_0 = F.argmax(x, dim=0, keepdim=False)
+    out_dim_1 = F.argmax(x, dim=1, keepdim=False)
+    out_dim_none_keepdim = F.argmax(x, dim=None, keepdim=True)
+    out_dim_0_keepdim = F.argmax(x, dim=0, keepdim=True)
+    out_dim_1_keepdim = F.argmax(x, dim=1, keepdim=True)
 
     assert out_dim_none.asnumpy() == 7
     assert np.all(out_dim_0.asnumpy() == np.array([2, 2, 2]))
@@ -0,0 +1,45 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.context as context
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+
+
+class CholeskyFuncNet(nn.Cell):
+    def construct(self, x):
+        return F.cholesky(x, upper=False)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_functional_cholesky(mode):
+    """
+    Feature: Test cholesky functional api.
+    Description: Test cholesky functional api for Graph and PyNative modes.
+    Expectation: the result match with expected result.
+    """
+    context.set_context(mode=mode, device_target="CPU")
+    x = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32)
+    net = CholeskyFuncNet()
+    output = net(x)
+    expect_output = np.array([[1., 0.], [1., 1.]])
+    assert np.allclose(output.asnumpy(), expect_output)
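The expected factor can be cross-checked outside MindSpore: for the symmetric positive-definite input above, NumPy's Cholesky yields the same lower-triangular matrix (a verification sketch, not part of the patch):

    import numpy as np

    a = np.array([[1., 1.], [1., 2.]])
    L = np.linalg.cholesky(a)       # lower-triangular factor, matching upper=False
    print(L)                        # [[1. 0.]
                                    #  [1. 1.]]
    print(np.allclose(L @ L.T, a))  # True: the factor recomposes the input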
@@ -94,7 +94,7 @@ def test_div_trunc_tensor_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = x.div(y, 'trunc')
+    output = x.div(y, rounding_mode='trunc')
     expected = np.array([[0., -0., -0., 0.],
                          [1., 1., 0., -0.],
                          [-0., 0., -1., -0.],

@@ -113,7 +113,7 @@ def test_div_floor_tensor_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = x.div(y, 'floor')
+    output = x.div(y, rounding_mode='floor')
     expected = np.array([[0., -1., -1., 0.],
                          [1., 1., 0., -1.],
                          [-1., 0., -2., -1.],

@@ -132,7 +132,7 @@ def test_div_functional_api():
                          [0.1062, 1.4581, 0.7759, -1.2344],
                          [-0.1830, -0.0313, 1.1908, -1.4757]]))
     y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308]))
-    output = F.div(x, y)
+    output = F.div(x, y, rounding_mode=None)
     expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639],
                          [0.2260, -3.4509, -1.2086, 6.8990],
                          [0.1322, 4.9764, -0.9564, 5.3484],

@@ -151,7 +151,7 @@ def test_div_trunc_functional_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = F.div(x, y, 'trunc')
+    output = F.div(x, y, rounding_mode='trunc')
     expected = np.array([[0., -0., -0., 0.],
                          [1., 1., 0., -0.],
                          [-0., 0., -1., -0.],

@@ -170,7 +170,7 @@ def test_div_floor_functional_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = F.div(x, y, 'floor')
+    output = F.div(x, y, rounding_mode='floor')
     expected = np.array([[0., -1., -1., 0.],
                          [1., 1., 0., -1.],
                          [-1., 0., -2., -1.],
@@ -0,0 +1,46 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.context as context
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+
+
+class UnfoldFuncNet(nn.Cell):
+    def construct(self, x):
+        return F.unfold(x, kernel_size=3, dilation=1, stride=1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_unfold_functional_api_modes(mode):
+    """
+    Feature: Test unfold functional api.
+    Description: Test unfold functional api for Graph and PyNative modes.
+    Expectation: the result match with expected result.
+    """
+    context.set_context(mode=mode, device_target="CPU")
+    x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
+    net = UnfoldFuncNet()
+    output = net(x)
+    expected_shape = (4, 36, 30, 30)
+    assert output.dtype == x.dtype
+    assert output.shape == expected_shape
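The asserted shape follows from sliding-window arithmetic: each output position gathers C * k * k = 4 * 3 * 3 = 36 values, and a 3x3 window at stride 1 fits 32 - 3 + 1 = 30 times along each spatial axis. A quick check of that arithmetic (illustrative only):

    N, C, H, W = 4, 4, 32, 32
    k, d, s = 3, 1, 1                     # kernel_size, dilation, stride
    out = (H - d * (k - 1) - 1) // s + 1  # effective window spans d*(k-1)+1 pixels
    print((N, C * k * k, out, out))       # (4, 36, 30, 30)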
@@ -22,6 +22,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor, ops
 from mindspore.common import dtype as mstype
+from mindspore.ops import functional as F
 
 
 class NetArgmax(nn.Cell):
@@ -100,26 +101,27 @@ def test_argmax_high_dims():
     assert (ms_output.asnumpy() == np_output).all()
 
 
-def adaptive_argmax_functional(nptype):
-    x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(nptype))
-    output = ops.argmax(x, axis=-1)
-    expected = np.array([1, 0, 0]).astype(np.int32)
-    np.testing.assert_array_almost_equal(output.asnumpy(), expected)
+class ArgmaxFuncNet(nn.Cell):
+    def construct(self, x):
+        return F.argmax(x, dim=-1)
 
 
 @pytest.mark.level1
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_argmax_float32_functional():
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_functional_argmax(mode):
     """
-    Feature: test argmax functional api.
-    Description: test float32 inputs.
+    Feature: Test argmax functional api.
+    Description: Test argmax functional api for Graph and PyNative modes.
     Expectation: the result match with expected result.
     """
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    adaptive_argmax_functional(np.float32)
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    adaptive_argmax_functional(np.float32)
+    context.set_context(mode=mode, device_target="GPU")
+    x = Tensor([[1, 20, 5], [67, 8, 9], [130, 24, 15]], mstype.float32)
+    net = ArgmaxFuncNet()
+    output = net(x)
+    expect_output = np.array([1, 0, 0]).astype(np.int32)
+    assert np.allclose(output.asnumpy(), expect_output)
 
 
 @pytest.mark.level0
@@ -234,12 +236,12 @@ def test_argmax_functional():
     Expectation: the result match with expected result.
     """
     x = Tensor([[1, 3, 2], [4, 6, 5], [7, 9, 8]], mstype.int32)
-    out_dim_none = ops.argmax(x, axis=None, keepdims=False)
-    out_dim_0 = ops.argmax(x, axis=0, keepdims=False)
-    out_dim_1 = ops.argmax(x, axis=1, keepdims=False)
-    out_dim_none_keepdim = ops.argmax(x, axis=None, keepdims=True)
-    out_dim_0_keepdim = ops.argmax(x, axis=0, keepdims=True)
-    out_dim_1_keepdim = ops.argmax(x, axis=1, keepdims=True)
+    out_dim_none = F.argmax(x, dim=None, keepdim=False)
+    out_dim_0 = F.argmax(x, dim=0, keepdim=False)
+    out_dim_1 = F.argmax(x, dim=1, keepdim=False)
+    out_dim_none_keepdim = F.argmax(x, dim=None, keepdim=True)
+    out_dim_0_keepdim = F.argmax(x, dim=0, keepdim=True)
+    out_dim_1_keepdim = F.argmax(x, dim=1, keepdim=True)
 
     assert out_dim_none.asnumpy() == 7
     assert np.all(out_dim_0.asnumpy() == np.array([2, 2, 2]))
@@ -159,7 +159,7 @@ def test_div_trunc_tensor_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = x.div(y, 'trunc')
+    output = x.div(y, rounding_mode='trunc')
     expected = np.array([[0., -0., -0., 0.],
                          [1., 1., 0., -0.],
                          [-0., 0., -1., -0.],

@@ -178,7 +178,7 @@ def test_div_floor_tensor_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = x.div(y, 'floor')
+    output = x.div(y, rounding_mode='floor')
     expected = np.array([[0., -1., -1., 0.],
                          [1., 1., 0., -1.],
                          [-1., 0., -2., -1.],

@@ -197,7 +197,7 @@ def test_div_functional_api():
                          [0.1062, 1.4581, 0.7759, -1.2344],
                          [-0.1830, -0.0313, 1.1908, -1.4757]]))
     y = Tensor(np.array([0.8032, 0.2930, -0.8113, -0.2308]))
-    output = F.div(x, y)
+    output = F.div(x, y, rounding_mode=None)
     expected = np.array([[-0.4620, -6.6051, 0.5676, 1.2639],
                          [0.2260, -3.4509, -1.2086, 6.8990],
                          [0.1322, 4.9764, -0.9564, 5.3484],

@@ -216,7 +216,7 @@ def test_div_trunc_functional_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = F.div(x, y, 'trunc')
+    output = F.div(x, y, rounding_mode='trunc')
     expected = np.array([[0., -0., -0., 0.],
                          [1., 1., 0., -0.],
                          [-0., 0., -1., -0.],

@@ -235,7 +235,7 @@ def test_div_floor_functional_api():
                          [-0.2601, -0.2397, 0.5832, 0.2250],
                          [0.0322, 0.7103, 0.6315, -0.8621]]))
     y = Tensor(np.array([0.6962, -0.4668, -0.2971, -0.6389]))
-    output = F.div(x, y, 'floor')
+    output = F.div(x, y, rounding_mode='floor')
     expected = np.array([[0., -1., -1., 0.],
                          [1., 1., 0., -1.],
                          [-1., 0., -2., -1.],
@@ -0,0 +1,46 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.context as context
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+
+
+class UnfoldFuncNet(nn.Cell):
+    def construct(self, x):
+        return F.unfold(x, kernel_size=3, dilation=1, stride=1)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_unfold_functional_api_modes(mode):
+    """
+    Feature: Test unfold functional api.
+    Description: Test unfold functional api for Graph and PyNative modes.
+    Expectation: the result match with expected result.
+    """
+    context.set_context(mode=mode, device_target="GPU")
+    x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
+    net = UnfoldFuncNet()
+    output = net(x)
+    expected_shape = (4, 36, 30, 30)
+    assert output.dtype == x.dtype
+    assert output.shape == expected_shape
@@ -0,0 +1,47 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.context as context
+from mindspore.common import dtype as mstype
+
+
+class CholeskyTensorNet(nn.Cell):
+    def construct(self, x):
+        return x.cholesky(upper=False)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_tensor_cholesky(mode):
+    """
+    Feature: Test cholesky tensor api.
+    Description: Test cholesky tensor api for Graph and PyNative modes.
+    Expectation: the result match with expected result.
+    """
+    context.set_context(mode=mode)
+    x = Tensor([[1.0, 1.0], [1.0, 2.0]], mstype.float32)
+    net = CholeskyTensorNet()
+    output = net(x)
+    expect_output = np.array([[1., 0.], [1., 1.]])
+    assert np.allclose(output.asnumpy(), expect_output)
@@ -0,0 +1,47 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.context as context
+from mindspore.common import dtype as mstype
+
+
+class CholeskyInverseTensorNet(nn.Cell):
+    def construct(self, x):
+        return x.cholesky_inverse()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_tensor_cholesky_inverse(mode):
+    """
+    Feature: Test cholesky_inverse tensor api.
+    Description: Test cholesky_inverse tensor api for Graph and PyNative modes.
+    Expectation: the result match with expected result.
+    """
+    context.set_context(mode=mode)
+    x = Tensor([[2, 0, 0], [4, 1, 0], [-1, 1, 2]], mstype.float32)
+    net = CholeskyInverseTensorNet()
+    output = net(x)
+    expect_output = np.array([[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]])
+    assert np.allclose(output.asnumpy(), expect_output)
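Given a lower-triangular factor L, `cholesky_inverse` returns the inverse of L @ L.T; the expected matrix above reproduces with NumPy (verification sketch, not part of the patch):

    import numpy as np

    L = np.array([[2., 0., 0.], [4., 1., 0.], [-1., 1., 2.]])
    print(np.linalg.inv(L @ L.T))
    # [[ 5.8125 -2.625   0.625 ]
    #  [-2.625   1.25   -0.25  ]
    #  [ 0.625  -0.25    0.25  ]]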
@@ -0,0 +1,47 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.context as context
+from mindspore.common import dtype as mstype
+
+
+class UnfoldTensorNet(nn.Cell):
+    def construct(self, x):
+        return x.unfold(kernel_size=3, dilation=1, stride=1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_tensor_unfold(mode):
+    """
+    Feature: Test unfold tensor api.
+    Description: Test unfold tensor api for Graph and PyNative modes.
+    Expectation: the result match with expected result.
+    """
+    context.set_context(mode=mode)
+    x = Tensor(np.ones((4, 4, 32, 32)), mstype.float32)
+    net = UnfoldTensorNet()
+    output = net(x)
+    expected_shape = (4, 36, 30, 30)
+    assert output.dtype == x.dtype
+    assert output.shape == expected_shape