!46604 tensor_sum_to_size_master

Merge pull request !46604 from yide12/tensor_sum_to_size_master
This commit is contained in:
i-robot 2022-12-12 11:29:01 +00:00 committed by Gitee
commit 6775552838
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 202 additions and 7 deletions

View File

@ -266,6 +266,7 @@ mindspore.ops
mindspore.ops.square mindspore.ops.square
mindspore.ops.sub mindspore.ops.sub
mindspore.ops.subtract mindspore.ops.subtract
mindspore.ops.sum_to_size
mindspore.ops.svd mindspore.ops.svd
mindspore.ops.t mindspore.ops.t
mindspore.ops.tan mindspore.ops.tan

View File

@ -0,0 +1,6 @@
mindspore.Tensor.sum_to_size
============================
.. py:method:: mindspore.Tensor.sum_to_size(*size)
详情请参考 :func:`mindspore.ops.sum_to_size`

View File

@ -251,6 +251,7 @@ mindspore.Tensor
mindspore.Tensor.sub mindspore.Tensor.sub
mindspore.Tensor.subtract mindspore.Tensor.subtract
mindspore.Tensor.sum mindspore.Tensor.sum
mindspore.Tensor.sum_to_size
mindspore.Tensor.svd mindspore.Tensor.svd
mindspore.Tensor.swapaxes mindspore.Tensor.swapaxes
mindspore.Tensor.T mindspore.Tensor.T

View File

@ -0,0 +1,16 @@
mindspore.ops.sum_to_size
=========================
.. py:function:: mindspore.ops.sum_to_size(x, *size)
将Tensor `x` 加和成 `size``size` 必须可以扩展到Tensor的大小。
参数:
- **x** (Tensor) - 求和的Tensor。
- **size** (Union[tuple(int), int]) - 期望输出Tensor的shape。
返回:
Tensor根据 `size``x` 进行求和的结果。
异常:
- **ValueError** - `size` 不能扩展成 `x` 的大小。

View File

@ -257,6 +257,7 @@
mindspore.Tensor.sub mindspore.Tensor.sub
mindspore.Tensor.subtract mindspore.Tensor.subtract
mindspore.Tensor.sum mindspore.Tensor.sum
mindspore.Tensor.sum_to_size
mindspore.Tensor.svd mindspore.Tensor.svd
mindspore.Tensor.swapaxes mindspore.Tensor.swapaxes
mindspore.Tensor.T mindspore.Tensor.T

View File

@ -266,6 +266,7 @@ Element-by-Element Operations
mindspore.ops.square mindspore.ops.square
mindspore.ops.sub mindspore.ops.sub
mindspore.ops.subtract mindspore.ops.subtract
mindspore.ops.sum_to_size
mindspore.ops.svd mindspore.ops.svd
mindspore.ops.t mindspore.ops.t
mindspore.ops.tan mindspore.ops.tan

View File

@ -329,6 +329,7 @@ BuiltInTypeMap &GetMethodMap() {
{"true_divide", std::string("true_divide")}, // true_divide() {"true_divide", std::string("true_divide")}, // true_divide()
{"triu", std::string("triu")}, // triu() {"triu", std::string("triu")}, // triu()
{"subtract", std::string("subtract")}, // true_divide() {"subtract", std::string("subtract")}, // true_divide()
{"sum_to_size", std::string("sum_to_size")}, // sum_to_size()
{"exp", std::string("exp")}, // P.Exp() {"exp", std::string("exp")}, // P.Exp()
{"repeat", std::string("repeat")}, // C.repeat_elements {"repeat", std::string("repeat")}, // C.repeat_elements
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli() {"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()

View File

@ -1205,8 +1205,8 @@ def permute(x, *dims):
""" """
if dims is None: if dims is None:
raise ValueError(f"For Tensor.permute, the dims must not be none.") raise ValueError(f"For Tensor.permute, the dims must not be none.")
if len(dims) == 1: if len(dims) == 1 and isinstance(dims[0], tuple):
return F.permute(x, *dims) return F.permute(x, dims[0])
return F.permute(x, dims) return F.permute(x, dims)
@ -1989,6 +1989,11 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disa
return res.astype(dtype) return res.astype(dtype)
def sum_to_size(x, *size):
"""For details, please refer to :func:`mindspore.ops.sum_to_size`."""
return F.sum_to_size(x, *size)
def repeat(x, repeats, axis=None): def repeat(x, repeats, axis=None):
""" """
Repeat elements of an array. Repeat elements of an array.

View File

@ -1849,8 +1849,8 @@ class Tensor(Tensor_):
self._init_check() self._init_check()
if not dims: if not dims:
raise ValueError(f"For Tensor.permute, the dims must not be none.") raise ValueError(f"For Tensor.permute, the dims must not be none.")
if len(dims) == 1: if len(dims) == 1 and isinstance(dims[0], tuple):
return tensor_operator_registry.get("permute")(self, *dims) return tensor_operator_registry.get("permute")(self, dims[0])
return tensor_operator_registry.get("permute")(self, dims) return tensor_operator_registry.get("permute")(self, dims)
def positive(self): def positive(self):
@ -3067,6 +3067,11 @@ class Tensor(Tensor_):
res += initial res += initial
return res.astype(dtype) return res.astype(dtype)
def sum_to_size(self, *size):
"""For details, please refer to :func:`mindspore.ops.sum_to_size`."""
self._init_check()
return tensor_operator_registry.get("sum_to_size")(self, *size)
def repeat(self, repeats, axis=None): def repeat(self, repeats, axis=None):
""" """
Repeat elements of a tensor. Repeat elements of a tensor.

View File

@ -2386,9 +2386,8 @@ class CTCLoss(LossBase):
def construct(self, log_probs, targets, input_lengths, target_lengths): def construct(self, log_probs, targets, input_lengths, target_lengths):
if len(log_probs.shape) == 2: if len(log_probs.shape) == 2:
n, c = log_probs.shape log_probs = log_probs.expand_dims(-2)
log_probs = log_probs.reshape((n, 1, c)) targets = targets.expand_dims(0)
targets = targets.reshape(1, targets.shape[0])
if isinstance(input_lengths, int): if isinstance(input_lengths, int):
input_lengths = Tensor([input_lengths], mstype.int32) input_lengths = Tensor([input_lengths], mstype.int32)
else: else:

View File

@ -172,6 +172,7 @@ from .math_func import (
real, real,
sub, sub,
subtract, subtract,
sum_to_size,
sqrt, sqrt,
square, square,
tensor_mul, tensor_mul,

View File

@ -809,6 +809,48 @@ def subtract(x, other, *, alpha=1):
return tensor_sub(x, alpha * other) return tensor_sub(x, alpha * other)
def sum_to_size(x, *size):
"""
Sum `x` Tensor to the `size`. `size` must be expandable to the Tensor size.
Args:
x (Tensor): The Tensor to be summed.
size (Union[tuple(int), int]): The expected shape of output Tensor.
Returns:
Tensor, the sum result of `x` according to the `size`.
Raises:
ValueError: If `size` is not expandable to the size of `x`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.random.randn(3, 3, 3, 3, 3, 3), mindspore.float32)
>>> output = ops.sum_to_size(x, (1, 3, 1, 3))
>>> print(output.shape)
(1, 3, 1, 3)
"""
if len(size) == 1 and isinstance(size[0], tuple):
size = size[0]
shape_x = x.shape
if len(size) > x.ndim:
raise ValueError(f"For sum_to_size, size {size} is not expandable to the tensor size {shape_x}.")
if len(size) < x.ndim:
pre_axis = tuple([axis for axis in range(x.ndim - len(size))])
x = x.sum(pre_axis)
axes = []
for i, element in enumerate(size):
if element != x.shape[i] and element == 1:
axes.append(i)
elif element != x.shape[i]:
raise ValueError(f"For sum_to_size, size {size} is not expandable to the tensor size {shape_x}.")
if axes:
return x.sum(tuple(axes), keepdims=True)
return x
def true_divide(dividend, divisor): def true_divide(dividend, divisor):
r""" r"""
Alias for Tensor.div() with :math:`rounding\_mode=None`. Alias for Tensor.div() with :math:`rounding\_mode=None`.
@ -9363,6 +9405,7 @@ __all__ = [
'tensor_sub', 'tensor_sub',
'sub', 'sub',
'subtract', 'subtract',
'sum_to_size',
'tensor_mul', 'tensor_mul',
'mul', 'mul',
'multiply', 'multiply',

View File

@ -140,6 +140,7 @@ tensor_operator_registry.register('bincount', bincount)
tensor_operator_registry.register('sqrt', sqrt) tensor_operator_registry.register('sqrt', sqrt)
tensor_operator_registry.register('square', square) tensor_operator_registry.register('square', square)
tensor_operator_registry.register('sub', sub) tensor_operator_registry.register('sub', sub)
tensor_operator_registry.register('sum_to_size', sum_to_size)
tensor_operator_registry.register('triu', Triu) tensor_operator_registry.register('triu', Triu)
tensor_operator_registry.register('tan', P.Tan) tensor_operator_registry.register('tan', P.Tan)
tensor_operator_registry.register('t', t) tensor_operator_registry.register('t', t)

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, ops
class Net(nn.Cell):
def construct(self, x):
return ops.sum_to_size(x, (3, 1, 3))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_sum_to_size(mode):
"""
Feature: ops.sum_to_size
Description: Verify the result of sum_to_size
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor([[[24, 20, 39],
[79, 67, 43],
[62, 0, 95]],
[[74, 5, 33],
[0, 35, 78],
[67, 0, 29]],
[[45, 42, 77],
[70, 61, 72],
[23, 82, 47]]], ms.float32)
net = Net()
output = net(x)
expect_output = [[[165, 87, 177]],
[[141, 40, 140]],
[[138, 185, 196]]]
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return x.sum_to_size((3, 1, 3))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_sum_to_size(mode):
"""
Feature: Tensor.sum_to_size
Description: Verify the result of sum_to_size
Expectation: success
"""
ms.set_context(mode=mode)
x = Tensor([[[24, 20, 39],
[79, 67, 43],
[62, 0, 95]],
[[74, 5, 33],
[0, 35, 78],
[67, 0, 29]],
[[45, 42, 77],
[70, 61, 72],
[23, 82, 47]]], ms.float32)
net = Net()
output = net(x)
expect_output = [[[165, 87, 177]],
[[141, 40, 140]],
[[138, 185, 196]]]
assert np.allclose(output.asnumpy(), expect_output)