!46438 Add tensor.inner & ops.inner
Merge pull request !46438 from shaojunsong/feature/inner
This commit is contained in:
commit
baf293e4cc
|
@ -338,6 +338,7 @@ Reduction函数
|
|||
mindspore.ops.adjoint
|
||||
mindspore.ops.batch_dot
|
||||
mindspore.ops.dot
|
||||
mindspore.ops.inner
|
||||
mindspore.ops.matmul
|
||||
mindspore.ops.matrix_solve
|
||||
mindspore.ops.mm
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.inner
|
||||
=======================
|
||||
|
||||
.. py:method:: mindspore.Tensor.inner(other)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.inner`。
|
|
@ -134,6 +134,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.index_add
|
||||
mindspore.Tensor.index_fill
|
||||
mindspore.Tensor.init_data
|
||||
mindspore.Tensor.inner
|
||||
mindspore.Tensor.inplace_update
|
||||
mindspore.Tensor.int
|
||||
mindspore.Tensor.inv
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
mindspore.ops.inner
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.inner(x, other)
|
||||
|
||||
计算两个1D Tensor的点积。对于更高维度来说,计算结果为在最后一维上,逐元素乘法的和。
|
||||
|
||||
..note::
|
||||
如果 `x` 或 `other` 之一是标量,那么相当于 :func:`mindspore.ops.mul(x, other)`。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 第一个输入。
|
||||
- **other** (Tensor) - 第二个输入。
|
||||
|
||||
返回:
|
||||
Tensor,内积的结果。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `x` 和 `other` 都不是标量,且两者的最后一维不相同。
|
|
@ -140,6 +140,7 @@
|
|||
mindspore.Tensor.index_add
|
||||
mindspore.Tensor.index_fill
|
||||
mindspore.Tensor.init_data
|
||||
mindspore.Tensor.inner
|
||||
mindspore.Tensor.inplace_update
|
||||
mindspore.Tensor.int
|
||||
mindspore.Tensor.inv
|
||||
|
|
|
@ -338,6 +338,7 @@ Linear Algebraic Functions
|
|||
mindspore.ops.adjoint
|
||||
mindspore.ops.batch_dot
|
||||
mindspore.ops.dot
|
||||
mindspore.ops.inner
|
||||
mindspore.ops.matmul
|
||||
mindspore.ops.matrix_solve
|
||||
mindspore.ops.mm
|
||||
|
|
|
@ -427,6 +427,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"lstsq", std::string("lstsq")}, // lstsq()
|
||||
{"mvlgamma", std::string("mvlgamma")}, // mvlgamma()
|
||||
{"matmul", std::string("matmul")}, // matmul()
|
||||
{"inner", std::string("inner")}, // inner()
|
||||
{"maximum", std::string("maximum")}, // maximum()
|
||||
{"msort", std::string("msort")}, // msort()
|
||||
{"mm", std::string("mm")}, // mm()
|
||||
|
|
|
@ -3113,6 +3113,11 @@ def matmul(x, y):
|
|||
return F.matmul(x, y)
|
||||
|
||||
|
||||
def inner(x, other):
|
||||
"""Computes the inner product of 2 tensors."""
|
||||
return F.inner(x, other)
|
||||
|
||||
|
||||
def float_bool(x):
|
||||
"""Implementation of `float_bool`."""
|
||||
return x != 0.0
|
||||
|
|
|
@ -4304,6 +4304,13 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('matmul')(self, tensor2)
|
||||
|
||||
def inner(self, other):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.inner`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('inner')(self, other)
|
||||
|
||||
def maximum(self, other):
|
||||
r"""
|
||||
Computes the maximum of input tensors element-wise.
|
||||
|
|
|
@ -300,6 +300,7 @@ from .math_func import (
|
|||
gumbel_softmax,
|
||||
kaiser_window,
|
||||
matmul,
|
||||
inner,
|
||||
baddbmm,
|
||||
cummin,
|
||||
cummax,
|
||||
|
|
|
@ -7258,6 +7258,69 @@ def matmul(x1, x2):
|
|||
return reshape_op(res, shape_out)
|
||||
|
||||
|
||||
def inner(x, other):
|
||||
r"""
|
||||
Computes the dot product of 1D tensors. For higher dimensions, the result will be the summation of the element-wise
|
||||
production along their last dimension.
|
||||
|
||||
Note:
|
||||
If either `x` or `other` is a Tensor scalar, the result is equivalent to mindspore.mul(x, other).
|
||||
|
||||
Args:
|
||||
x (Tensor): First input.
|
||||
other (Tensor): Second input.
|
||||
|
||||
Returns:
|
||||
Tensor, the result of the inner product.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither `x` nor `other` is scalar, and the last dimension of the two input tensors do not match.
|
||||
|
||||
Examples:
|
||||
>>> # case1: 2 1D tensors
|
||||
>>> x = ms.Tensor([1, 2, 3], ms.float32)
|
||||
>>> y = ms.Tensor([4, 5, 6], ms.float32)
|
||||
>>> output = ops.inner(x, y)
|
||||
>>> print(output)
|
||||
32
|
||||
>>> # case2: Tensor scalar and tensor
|
||||
>>> x = ms.Tensor([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [4, 5, 6]]], ms.float32)
|
||||
>>> y = ms.Tensor(2, ms.float32)
|
||||
>>> output = ops.inner(x, y)
|
||||
>>> print(output)
|
||||
[[[ 2. 4. 6.]
|
||||
[ 6. 4. 2.]]
|
||||
|
||||
[[ 8. 10. 12.]
|
||||
[ 8. 10. 12.]]]
|
||||
>>> # case3: Two tensors
|
||||
>>> x = ms.Tensor([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [4, 5, 6]]], ms.float32)
|
||||
>>> y = ms.Tensor([[2, 3, 4], [4, 3, 2]], ms.float32)
|
||||
>>> output = ops.inner(x, y)
|
||||
>>> print(output)
|
||||
[[[20. 16.]
|
||||
[16. 20.]]
|
||||
|
||||
[[47. 43.]
|
||||
[47. 43.]]]
|
||||
"""
|
||||
x_dim = x.ndim
|
||||
other_dim = other.ndim
|
||||
|
||||
if x_dim == 0 or other_dim == 0:
|
||||
return x * other
|
||||
|
||||
if x_dim == 1 or other_dim == 1:
|
||||
return matmul(x, other)
|
||||
|
||||
x_shape = x.shape
|
||||
other_shape = other.shape
|
||||
if x_shape[-1] != other_shape[-1]:
|
||||
raise ValueError(f"For 'inner', the last dimension of 'x' and 'other' must be the same, \
|
||||
but got x.shape: {x_shape} and other.shape: {other_shape}.")
|
||||
return matmul(x, other.swapaxes(-2, -1))
|
||||
|
||||
|
||||
def bmm(input_x, mat2):
|
||||
"""
|
||||
Computes matrix multiplication between two tensors by batch.
|
||||
|
@ -9271,6 +9334,7 @@ __all__ = [
|
|||
'gumbel_softmax',
|
||||
'kaiser_window',
|
||||
'matmul',
|
||||
'inner',
|
||||
'cummin',
|
||||
'cummax',
|
||||
'cumsum',
|
||||
|
|
|
@ -164,6 +164,7 @@ tensor_operator_registry.register('flatten', P.Flatten)
|
|||
tensor_operator_registry.register('transpose', P.Transpose)
|
||||
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
|
||||
tensor_operator_registry.register('matmul', matmul)
|
||||
tensor_operator_registry.register('inner', inner)
|
||||
tensor_operator_registry.register('xdivy', P.Xdivy)
|
||||
tensor_operator_registry.register('argmax', P.Argmax)
|
||||
tensor_operator_registry.register('argmin', P.Argmin)
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
output = ops.inner(x, y)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_inner_normal(mode):
|
||||
"""
|
||||
Feature: ops.inner
|
||||
Description: Verify the result of ops.inner
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [4, 5, 6]]], ms.float32)
|
||||
y = ms.Tensor([[2, 3, 4], [4, 3, 2]], ms.float32)
|
||||
out = net(x, y)
|
||||
expect_out = np.array([[[20, 16], [16, 20]], [[47, 43], [47, 43]]], dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_inner_with_scalar(mode):
|
||||
"""
|
||||
Feature: ops.inner
|
||||
Description: Verify the result of ops.inner
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [4, 5, 6]]], ms.float32)
|
||||
y = ms.Tensor(2, ms.float32)
|
||||
out = net(x, y)
|
||||
expect_out = np.array([[[2, 4, 6], [6, 4, 2]], [[8, 10, 12], [8, 10, 12]]], dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_inner_1d(mode):
|
||||
"""
|
||||
Feature: ops.inner
|
||||
Description: Verify the result of ops.inner
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([1, 2, 3], ms.float32)
|
||||
y = ms.Tensor([4, 5, 6], ms.float32)
|
||||
out = net(x, y)
|
||||
expect_out = np.array(32, dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
output = x.inner(y)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_inner_normal(mode):
|
||||
"""
|
||||
Feature: tensor.inner
|
||||
Description: Verify the result of tensor.inner
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [4, 5, 6]]], ms.float32)
|
||||
y = ms.Tensor([[2, 3, 4], [4, 3, 2]], ms.float32)
|
||||
out = net(x, y)
|
||||
expect_out = np.array([[[20, 16], [16, 20]], [[47, 43], [47, 43]]], dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_inner_with_scalar(mode):
|
||||
"""
|
||||
Feature: tensor.inner
|
||||
Description: Verify the result of tensor.inner
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [4, 5, 6]]], ms.float32)
|
||||
y = ms.Tensor(2, ms.float32)
|
||||
out = net(x, y)
|
||||
expect_out = np.array([[[2, 4, 6], [6, 4, 2]], [[8, 10, 12], [8, 10, 12]]], dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_inner_1d(mode):
|
||||
"""
|
||||
Feature: tensor.inner
|
||||
Description: Verify the result of tensor.inner
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = ms.Tensor([1, 2, 3], ms.float32)
|
||||
y = ms.Tensor([4, 5, 6], ms.float32)
|
||||
out = net(x, y)
|
||||
expect_out = np.array(32, dtype=np.float32)
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
Loading…
Reference in New Issue