!46320 ops_view_einsum_master

Merge pull request !46320 from yide12/tensor_view_einsum_master
This commit is contained in:
i-robot 2022-12-13 09:36:13 +00:00 committed by Gitee
commit 5d1f943a82
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 295 additions and 0 deletions

View File

@ -416,6 +416,7 @@ Array操作
mindspore.ops.diagonal
mindspore.ops.dyn_shape
mindspore.ops.dsplit
mindspore.ops.einsum
mindspore.ops.expand
mindspore.ops.expand_dims
mindspore.ops.flip
@ -488,6 +489,7 @@ Array操作
mindspore.ops.unsorted_segment_sum
mindspore.ops.unsqueeze
mindspore.ops.unstack
mindspore.ops.view_as_real
mindspore.ops.vsplit
mindspore.ops.where

View File

@ -0,0 +1,16 @@
mindspore.ops.einsum
====================
.. py:function:: mindspore.ops.einsum(equation, *operands)
基于爱因斯坦求和约定Einsum符号指定维度对输入Tensor元素的乘积求和。你可以使用这个运算符来执行对角线、减法、转置、矩阵乘法、乘法、内积运算等等。
参数:
- **equation** str - 基于爱因斯坦求和约定的符号表示想要执行的操作。符号只能包含字母、逗号、省略号和箭头。字母表示输入Tensor维数逗号表示单独的Tensor省略号表示忽略的Tensor维数箭头的左边表示输入Tensor右边表示期望输出的维度。
- **operands** Tensor - 用于计算的输入Tensor。Tensor的数据类型必须相同。
返回:
Tensorshape可以根据 `equation` 得到。数据类型和输入Tensor相同。
异常:
- **TypeError** - `equation` 无效或者不匹配输入Tensor。

View File

@ -0,0 +1,15 @@
mindspore.ops.view_as_real
==========================
.. py:function:: mindspore.ops.view_as_real(x)
将复数Tensor看作实数Tensor。返回的实数Tensor的最后一维大小为2由复数的实部和虚部组成。
参数:
- **x** (Tensor) - 输入必须是一个复数Tensor。
返回:
实数Tensor。
异常:
- **TypeError** - 输入Tensor不是复数类型。

View File

@ -416,6 +416,7 @@ Array Operation
mindspore.ops.diagonal
mindspore.ops.dsplit
mindspore.ops.dyn_shape
mindspore.ops.einsum
mindspore.ops.expand
mindspore.ops.expand_dims
mindspore.ops.flip
@ -488,6 +489,7 @@ Array Operation
mindspore.ops.unsorted_segment_sum
mindspore.ops.unsqueeze
mindspore.ops.unstack
mindspore.ops.view_as_real
mindspore.ops.vsplit
mindspore.ops.where

View File

@ -196,6 +196,8 @@ from .math_func import (
floormod,
lcm,
tensor_exp,
einsum,
view_as_real,
exp,
tensor_expm1,
expm1,

View File

@ -6048,6 +6048,40 @@ def atleast_3d(inputs):
return [_expand3(arr) for arr in inputs]
def view_as_real(x):
r"""
View a complex Tensor as a real Tensor.
The size of last dimension of the returned real Tensor is 2, and the last dimension is composed of
the real and imaginary components of complex numbers.
Args:
x (Tensor): the input must be a complex Tensor.
Returns:
A real Tensor.
Raises:
TypeError: If the input Tensor is not a complex Tensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> x = Tensor([2+1j,2+3j,2-1j,2], ms.complex64),
>>> print(ops.view_as_real(x))
[[ 2. 1.]
[ 2. 3.]
[ 2. -1.]
[ 2. 0.]]
"""
if not is_complex(x):
raise TypeError("For view_as_real, the dtype of input Tensor must be complex.")
real_part = x.real().expand_dims(-1)
imag_part = x.imag().expand_dims(-1)
con = _get_cache_prim(ops.Concat)(-1)
return con((real_part, imag_part))
def vstack(inputs):
r"""
Stacks tensors in sequence vertically.
@ -8695,6 +8729,83 @@ def cross(input, other, dim=None):
return cross_op(input, other)
def einsum(equation, *operands):
r"""
Sums the product of the elements of the input Tensor along
dimensions specified notation based on the Einstein summation convention(Einsum).
You can use this operator to perform diagonal, reducesum, transpose, matmul, mul, inner product operations, etc.
Args:
equation (str): Notation based on the Einstein summation convention, represent the operation you want to do.
the value can contain only letters, commas, ellipsis and arrow.
The letters represent input tensor dimension, commas represent separate tensors, ellipsis indicates
the tensor dimension that you do not care about, the left of the arrow indicates the input tensors,
and the right of it indicates the desired output dimension.
operands (Tensor): Input tensor used for calculation. The dtype of the tensor must be the same.
Returns:
Tensor, the shape of it can be obtained from the `equation` , and the dtype is the same as input tensors.
Raises:
TypeError: If `equation` is invalid, or the `equation` does not match the input tensor.
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> equation = "i->"
>>> output = ops.einsum(equation, x)
>>> print(output)
[7.]
>>>
>>> x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
>>> y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32)
>>> equation = "i,i->i"
>>> output = ops.einsum(equation, x, y)
>>> print(output)
[ 2. 8. 12.]
>>>
>>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
>>> y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), mindspore.float32)
>>> equation = "ij,jk->ik"
>>> output = ops.einsum(equation, x, y)
>>> print(output)
[[16. 22.]
[37. 52.]]
>>>
>>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
>>> equation = "ij->ji"
>>> output = ops.einsum(equation, x)
>>> print(output)
[[1. 4.]
[2. 5.]
[3. 6.]]
>>>
>>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
>>> equation = "ij->j"
>>> output = ops.einsum(equation, x)
>>> print(output)
[5. 7. 9.]
>>>
>>> x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32)
>>> equation = "...->"
>>> output = einsum(equation, x)
>>> print(output)
[21.]
>>>
>>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
>>> y = Tensor(np.array([2.0, 4.0, 1.0]), mindspore.float32)
>>> equation = "j,i->ji"
>>> output = ops.einsum(equation, x, y)
>>> print(output)
[[ 2. 4. 1.]
[ 4. 8. 2.]
[ 6. 12. 3.]]
"""
return _get_cache_prim(P.Einsum)(equation)(operands)
def erfinv(input):
r"""
Computes the inverse error function of input. The inverse error function is defined in the range `(-1, 1)` as:
@ -9654,6 +9765,7 @@ __all__ = [
'atleast_2d',
'cartesian_prod',
'atleast_3d',
'view_as_real',
'vstack',
'combinations',
'dist',
@ -9681,6 +9793,7 @@ __all__ = [
'cholesky_inverse',
'conj',
'cross',
'einsum',
'erfinv',
'less_equal',
'cumprod',

View File

@ -0,0 +1,98 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class Net(nn.Cell):
def __init__(self, equation):
super().__init__()
self.equation = equation
def construct(self, *operands):
return ops.einsum(self.equation, *operands)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_einsum(mode):
"""
Feature: ops.einsum
Description: Verify the result of einsum
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
equation = "i->"
net = Net(equation)
output = net(x)
expect_output = [7.]
assert np.allclose(output.asnumpy(), expect_output)
x = Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
y = Tensor(np.array([2.0, 4.0, 3.0]), ms.float32)
equation = "i,i->i"
net = Net(equation)
output = net(x, y)
expect_output = [2., 8., 12.]
assert np.allclose(output.asnumpy(), expect_output)
x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), ms.float32)
y = Tensor(np.array([[2.0, 3.0], [1.0, 2.0], [4.0, 5.0]]), ms.float32)
equation = "ij,jk->ik"
net = Net(equation)
output = net(x, y)
expect_output = [[16., 22.],
[37., 52.]]
assert np.allclose(output.asnumpy(), expect_output)
x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), ms.float32)
equation = "ij->ji"
net = Net(equation)
output = net(x)
expect_output = [[1., 4.],
[2., 5.],
[3., 6.]]
assert np.allclose(output.asnumpy(), expect_output)
x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), ms.float32)
equation = "ij->j"
net = Net(equation)
output = net(x)
expect_output = [5., 7., 9.]
assert np.allclose(output.asnumpy(), expect_output)
x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), ms.float32)
equation = "...->"
net = Net(equation)
output = net(x)
expect_output = [21.]
assert np.allclose(output.asnumpy(), expect_output)
x = Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
y = Tensor(np.array([2.0, 4.0, 1.0]), ms.float32)
equation = "j,i->ji"
net = Net(equation)
output = net(x, y)
expect_output = [[2., 4., 1.],
[4., 8., 2.],
[6., 12., 3.]]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class Net(nn.Cell):
def construct(self, x):
return ops.view_as_real(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_view_as_real(mode):
"""
Feature: ops.view_as_real
Description: Verify the result of view_as_real
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor([2 + 1j, 2 + 3j, 2 - 1j, 2], ms.complex64)
net = Net()
output = net(x)
expect_output = [[2., 1.],
[2., 3.],
[2., -1.],
[2., 0.]]
assert np.allclose(output.asnumpy(), expect_output)