forked from mindspore-Ecosystem/mindspore
add api: ops.det
This commit is contained in:
parent
360add4014
commit
876acb5fca
|
@ -250,8 +250,6 @@ mindspore.ops
|
|||
mindspore.ops.logical_or
|
||||
mindspore.ops.logical_xor
|
||||
mindspore.ops.logit
|
||||
mindspore.ops.log_matrix_determinant
|
||||
mindspore.ops.matrix_determinant
|
||||
mindspore.ops.mul
|
||||
mindspore.ops.multiply
|
||||
mindspore.ops.mvlgamma
|
||||
|
@ -366,15 +364,18 @@ Reduction函数
|
|||
mindspore.ops.cholesky
|
||||
mindspore.ops.cholesky_inverse
|
||||
mindspore.ops.batch_dot
|
||||
mindspore.ops.det
|
||||
mindspore.ops.dot
|
||||
mindspore.ops.inner
|
||||
mindspore.ops.inverse
|
||||
mindspore.ops.ger
|
||||
mindspore.ops.kron
|
||||
mindspore.ops.log_matrix_determinant
|
||||
mindspore.ops.matmul
|
||||
mindspore.ops.matrix_solve
|
||||
mindspore.ops.matrix_exp
|
||||
mindspore.ops.matrix_band_part
|
||||
mindspore.ops.matrix_determinant
|
||||
mindspore.ops.matrix_diag
|
||||
mindspore.ops.matrix_diag_part
|
||||
mindspore.ops.matrix_set_diag
|
||||
|
|
|
@ -3,15 +3,4 @@ mindspore.Tensor.swapaxes
|
|||
|
||||
.. py:method:: mindspore.Tensor.swapaxes(axis1, axis2)
|
||||
|
||||
交换Tensor的两个维度。
|
||||
|
||||
参数:
|
||||
- **axis1** (int) - 第一个维度。
|
||||
- **axis2** (int) - 第二个维度。
|
||||
|
||||
返回:
|
||||
转化后的Tensor,与输入具有相同的数据类型。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `axis1` 或 `axis2` 不是整数。
|
||||
- **ValueError** - `axis1` 或 `axis2` 不在 `[-ndim, ndim-1]` 范围内。
|
||||
详情请参考 :func:`mindspore.ops.swapaxes` 。
|
||||
|
|
|
@ -8,7 +8,7 @@ mindspore.ops.TripletMarginLoss
|
|||
创建一个标准,用于计算输入Tensor :math:`x` 、 :math:`x2` 和 :math:`x3` 与大于 :math:`0` 的 `margin` 之间的三元组损失值。
|
||||
可以用来测量样本之间的相似度。一个三元组包含 `a` 、 `p` 和 `n` (即分别代表示 `anchor` 、 `positive examples` 和 `negative examples` )。
|
||||
所有输入Tensor的shape都应该为 :math:`(N, D)` 。
|
||||
距离交换在V. Balntas、E. Riba等人在论文 `Learning local feature descriptors with triplets and shallow convolutional neural networks <http://158.109.8.37/files/BRP2016.pdf>`_ 中有详细的阐述。
|
||||
距离交换在V. Balntas、E. Riba等人的论文 `Learning local feature descriptors with triplets and shallow convolutional neural networks <http://158.109.8.37/files/BRP2016.pdf>`_ 中有详细的阐述。
|
||||
|
||||
对于每个小批量样本,损失值为:
|
||||
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.ops.det
|
||||
================================
|
||||
|
||||
.. py:function:: mindspore.ops.det(x)
|
||||
|
||||
:func:`mindspore.ops.matrix_determinant` 的别名。
|
|
@ -250,8 +250,6 @@ Element-wise Operations
|
|||
mindspore.ops.logical_or
|
||||
mindspore.ops.logical_xor
|
||||
mindspore.ops.logit
|
||||
mindspore.ops.log_matrix_determinant
|
||||
mindspore.ops.matrix_determinant
|
||||
mindspore.ops.mul
|
||||
mindspore.ops.multiply
|
||||
mindspore.ops.mvlgamma
|
||||
|
@ -366,15 +364,18 @@ Linear Algebraic Functions
|
|||
mindspore.ops.cholesky
|
||||
mindspore.ops.cholesky_inverse
|
||||
mindspore.ops.batch_dot
|
||||
mindspore.ops.det
|
||||
mindspore.ops.dot
|
||||
mindspore.ops.inner
|
||||
mindspore.ops.inverse
|
||||
mindspore.ops.ger
|
||||
mindspore.ops.kron
|
||||
mindspore.ops.log_matrix_determinant
|
||||
mindspore.ops.matmul
|
||||
mindspore.ops.matrix_solve
|
||||
mindspore.ops.matrix_exp
|
||||
mindspore.ops.matrix_band_part
|
||||
mindspore.ops.matrix_determinant
|
||||
mindspore.ops.matrix_diag
|
||||
mindspore.ops.matrix_diag_part
|
||||
mindspore.ops.matrix_set_diag
|
||||
|
|
|
@ -1321,7 +1321,7 @@ def resize(x, *new_shape):
|
|||
|
||||
def det(x):
|
||||
"""Computes the determinant of one or more square matrices."""
|
||||
return F.matrix_determinant(x)
|
||||
return F.det(x)
|
||||
|
||||
|
||||
def diagonal(x, offset=0, axis1=0, axis2=1):
|
||||
|
|
|
@ -1732,29 +1732,7 @@ class Tensor(Tensor_):
|
|||
|
||||
def swapaxes(self, axis1, axis2):
|
||||
"""
|
||||
Interchange two axes of a tensor.
|
||||
|
||||
Args:
|
||||
axis1 (int): First axis.
|
||||
axis2 (int): Second axis.
|
||||
|
||||
Returns:
|
||||
Transposed tensor, has the same data type as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `axis1` or `axis2` is not integer.
|
||||
ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
|
||||
>>> output = x.swapaxes(0, 2)
|
||||
>>> print(output.shape)
|
||||
(4,3,2)
|
||||
For details, please refer to :func:`mindspore.ops.swapaxes`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('swapaxes')(self, axis1, axis2)
|
||||
|
@ -2457,10 +2435,10 @@ class Tensor(Tensor_):
|
|||
|
||||
def det(self):
|
||||
r"""
|
||||
Refer to :func:`mindspore.Tensor.matrix_determinant`.
|
||||
Refer to :func:`mindspore.ops.det`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('matrix_determinant')(self)
|
||||
return tensor_operator_registry.get('det')(self)
|
||||
|
||||
def diff(self, n=1, axis=-1, prepend=None, append=None):
|
||||
r"""
|
||||
|
|
|
@ -234,6 +234,7 @@ from .math_func import (
|
|||
log_matrix_determinant,
|
||||
slogdet,
|
||||
matrix_determinant,
|
||||
det,
|
||||
linspace,
|
||||
matrix_solve,
|
||||
maximum,
|
||||
|
|
|
@ -1308,8 +1308,8 @@ def logdet(x):
|
|||
>>> print(output)
|
||||
[1.9459091 0.6931454]
|
||||
"""
|
||||
det = matrix_determinant(x)
|
||||
return log_(det)
|
||||
det_x = matrix_determinant(x)
|
||||
return log_(det_x)
|
||||
|
||||
|
||||
def floor(x):
|
||||
|
@ -3224,7 +3224,17 @@ def matrix_determinant(x):
|
|||
>>> print(output)
|
||||
[-16.5 21. ]
|
||||
"""
|
||||
return matrix_determinant_(x)
|
||||
return _get_cache_prim(P.MatrixDeterminant)()(x)
|
||||
|
||||
|
||||
def det(x):
|
||||
"""
|
||||
Alias for :func:`mindspore.ops.matrix_determinant` .
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
return matrix_determinant(x)
|
||||
|
||||
|
||||
def matrix_exp(x):
|
||||
|
@ -10017,6 +10027,7 @@ __all__ = [
|
|||
'logdet',
|
||||
'log_matrix_determinant',
|
||||
'matrix_determinant',
|
||||
'det',
|
||||
'linspace',
|
||||
'matrix_solve',
|
||||
'std',
|
||||
|
|
|
@ -196,6 +196,7 @@ tensor_operator_registry.register('random_categorical', random_categorical)
|
|||
tensor_operator_registry.register('mirror_pad', P.MirrorPad)
|
||||
tensor_operator_registry.register('minimum', P.Minimum)
|
||||
tensor_operator_registry.register('matrix_determinant', matrix_determinant)
|
||||
tensor_operator_registry.register('det', det)
|
||||
tensor_operator_registry.register('log1p', log1p)
|
||||
tensor_operator_registry.register('logdet', logdet)
|
||||
tensor_operator_registry.register('log_matrix_determinant', log_matrix_determinant)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return ops.det(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_det(mode):
|
||||
"""
|
||||
Feature: ops.det(x)
|
||||
Description: Verify the result of ops.det(x)
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = Tensor([[1.5, 2.0], [3, 4.6]], dtype=mstype.float32)
|
||||
output = net(x)
|
||||
expected = np.array(0.9)
|
||||
assert np.allclose(output.asnumpy(), expected)
|
Loading…
Reference in New Issue