add api: ops.det

This commit is contained in:
guozhibin 2023-01-30 14:44:56 +08:00
parent 360add4014
commit 876acb5fca
11 changed files with 82 additions and 46 deletions

View File

@ -250,8 +250,6 @@ mindspore.ops
mindspore.ops.logical_or
mindspore.ops.logical_xor
mindspore.ops.logit
mindspore.ops.log_matrix_determinant
mindspore.ops.matrix_determinant
mindspore.ops.mul
mindspore.ops.multiply
mindspore.ops.mvlgamma
@ -366,15 +364,18 @@ Reduction函数
mindspore.ops.cholesky
mindspore.ops.cholesky_inverse
mindspore.ops.batch_dot
mindspore.ops.det
mindspore.ops.dot
mindspore.ops.inner
mindspore.ops.inverse
mindspore.ops.ger
mindspore.ops.kron
mindspore.ops.log_matrix_determinant
mindspore.ops.matmul
mindspore.ops.matrix_solve
mindspore.ops.matrix_exp
mindspore.ops.matrix_band_part
mindspore.ops.matrix_determinant
mindspore.ops.matrix_diag
mindspore.ops.matrix_diag_part
mindspore.ops.matrix_set_diag

View File

@ -3,15 +3,4 @@ mindspore.Tensor.swapaxes
.. py:method:: mindspore.Tensor.swapaxes(axis1, axis2)
交换Tensor的两个维度。
参数:
- **axis1** (int) - 第一个维度。
- **axis2** (int) - 第二个维度。
返回:
转化后的Tensor与输入具有相同的数据类型。
异常:
- **TypeError** - `axis1``axis2` 不是整数。
- **ValueError** - `axis1``axis2` 不在 `[-ndim, ndim-1]` 范围内。
详情请参考 :func:`mindspore.ops.swapaxes`

View File

@ -8,7 +8,7 @@ mindspore.ops.TripletMarginLoss
创建一个标准用于计算输入Tensor :math:`x`:math:`x2`:math:`x3` 与大于 :math:`0``margin` 之间的三元组损失值。
可以用来测量样本之间的相似度。一个三元组包含 `a``p``n` (即分别代表示 `anchor``positive examples``negative examples` )。
所有输入Tensor的shape都应该为 :math:`(N, D)`
距离交换在V. Balntas、E. Riba等人论文 `Learning local feature descriptors with triplets and shallow convolutional neural networks <http://158.109.8.37/files/BRP2016.pdf>`_ 中有详细的阐述。
距离交换在V. Balntas、E. Riba等人论文 `Learning local feature descriptors with triplets and shallow convolutional neural networks <http://158.109.8.37/files/BRP2016.pdf>`_ 中有详细的阐述。
对于每个小批量样本,损失值为:

View File

@ -0,0 +1,6 @@
mindspore.ops.det
================================
.. py:function:: mindspore.ops.det(x)
:func:`mindspore.ops.matrix_determinant` 的别名。

View File

@ -250,8 +250,6 @@ Element-wise Operations
mindspore.ops.logical_or
mindspore.ops.logical_xor
mindspore.ops.logit
mindspore.ops.log_matrix_determinant
mindspore.ops.matrix_determinant
mindspore.ops.mul
mindspore.ops.multiply
mindspore.ops.mvlgamma
@ -366,15 +364,18 @@ Linear Algebraic Functions
mindspore.ops.cholesky
mindspore.ops.cholesky_inverse
mindspore.ops.batch_dot
mindspore.ops.det
mindspore.ops.dot
mindspore.ops.inner
mindspore.ops.inverse
mindspore.ops.ger
mindspore.ops.kron
mindspore.ops.log_matrix_determinant
mindspore.ops.matmul
mindspore.ops.matrix_solve
mindspore.ops.matrix_exp
mindspore.ops.matrix_band_part
mindspore.ops.matrix_determinant
mindspore.ops.matrix_diag
mindspore.ops.matrix_diag_part
mindspore.ops.matrix_set_diag

View File

@ -1321,7 +1321,7 @@ def resize(x, *new_shape):
def det(x):
"""Computes the determinant of one or more square matrices."""
return F.matrix_determinant(x)
return F.det(x)
def diagonal(x, offset=0, axis1=0, axis2=1):

View File

@ -1732,29 +1732,7 @@ class Tensor(Tensor_):
def swapaxes(self, axis1, axis2):
"""
Interchange two axes of a tensor.
Args:
axis1 (int): First axis.
axis2 (int): Second axis.
Returns:
Transposed tensor, has the same data type as the input.
Raises:
TypeError: If `axis1` or `axis2` is not integer.
ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.ones((2,3,4), dtype=np.float32))
>>> output = x.swapaxes(0, 2)
>>> print(output.shape)
(4,3,2)
For details, please refer to :func:`mindspore.ops.swapaxes`.
"""
self._init_check()
return tensor_operator_registry.get('swapaxes')(self, axis1, axis2)
@ -2457,10 +2435,10 @@ class Tensor(Tensor_):
def det(self):
r"""
Refer to :func:`mindspore.Tensor.matrix_determinant`.
Refer to :func:`mindspore.ops.det`.
"""
self._init_check()
return tensor_operator_registry.get('matrix_determinant')(self)
return tensor_operator_registry.get('det')(self)
def diff(self, n=1, axis=-1, prepend=None, append=None):
r"""

View File

@ -234,6 +234,7 @@ from .math_func import (
log_matrix_determinant,
slogdet,
matrix_determinant,
det,
linspace,
matrix_solve,
maximum,

View File

@ -1308,8 +1308,8 @@ def logdet(x):
>>> print(output)
[1.9459091 0.6931454]
"""
det = matrix_determinant(x)
return log_(det)
det_x = matrix_determinant(x)
return log_(det_x)
def floor(x):
@ -3224,7 +3224,17 @@ def matrix_determinant(x):
>>> print(output)
[-16.5 21. ]
"""
return matrix_determinant_(x)
return _get_cache_prim(P.MatrixDeterminant)()(x)
def det(x):
"""
Alias for :func:`mindspore.ops.matrix_determinant` .
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
return matrix_determinant(x)
def matrix_exp(x):
@ -10017,6 +10027,7 @@ __all__ = [
'logdet',
'log_matrix_determinant',
'matrix_determinant',
'det',
'linspace',
'matrix_solve',
'std',

View File

@ -196,6 +196,7 @@ tensor_operator_registry.register('random_categorical', random_categorical)
tensor_operator_registry.register('mirror_pad', P.MirrorPad)
tensor_operator_registry.register('minimum', P.Minimum)
tensor_operator_registry.register('matrix_determinant', matrix_determinant)
tensor_operator_registry.register('det', det)
tensor_operator_registry.register('log1p', log1p)
tensor_operator_registry.register('logdet', logdet)
tensor_operator_registry.register('log_matrix_determinant', log_matrix_determinant)

View File

@ -0,0 +1,48 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, x):
return ops.det(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_det(mode):
"""
Feature: ops.det(x)
Description: Verify the result of ops.det(x)
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor([[1.5, 2.0], [3, 4.6]], dtype=mstype.float32)
output = net(x)
expected = np.array(0.9)
assert np.allclose(output.asnumpy(), expected)