forked from mindspore-Ecosystem/mindspore
tensor_sum_to_size_master
This commit is contained in:
parent
d7a19bf0fe
commit
06de9234d3
|
@ -263,6 +263,7 @@ mindspore.ops
|
|||
mindspore.ops.square
|
||||
mindspore.ops.sub
|
||||
mindspore.ops.subtract
|
||||
mindspore.ops.sum_to_size
|
||||
mindspore.ops.svd
|
||||
mindspore.ops.tan
|
||||
mindspore.ops.true_divide
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.sum_to_size
|
||||
============================
|
||||
|
||||
.. py:method:: mindspore.Tensor.sum_to_size(*size)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.sum_to_size`。
|
|
@ -248,6 +248,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.sub
|
||||
mindspore.Tensor.subtract
|
||||
mindspore.Tensor.sum
|
||||
mindspore.Tensor.sum_to_size
|
||||
mindspore.Tensor.svd
|
||||
mindspore.Tensor.swapaxes
|
||||
mindspore.Tensor.T
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
mindspore.ops.sum_to_size
|
||||
=========================
|
||||
|
||||
.. py:function:: mindspore.ops.sum_to_size(x, *size)
|
||||
|
||||
将Tensor `x` 加和成 `size`。`size` 必须可以扩展到Tensor的大小。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 求和的Tensor。
|
||||
- **size** (Union[tuple(int), int]) - 期望输出Tensor的shape。
|
||||
|
||||
返回:
|
||||
Tensor,根据 `size` 对 `x` 进行求和的结果。
|
||||
|
||||
异常:
|
||||
- **ValueError** - `size` 不能扩展成 `x` 的大小。
|
|
@ -254,6 +254,7 @@
|
|||
mindspore.Tensor.sub
|
||||
mindspore.Tensor.subtract
|
||||
mindspore.Tensor.sum
|
||||
mindspore.Tensor.sum_to_size
|
||||
mindspore.Tensor.svd
|
||||
mindspore.Tensor.swapaxes
|
||||
mindspore.Tensor.T
|
||||
|
|
|
@ -263,6 +263,7 @@ Element-by-Element Operations
|
|||
mindspore.ops.square
|
||||
mindspore.ops.sub
|
||||
mindspore.ops.subtract
|
||||
mindspore.ops.sum_to_size
|
||||
mindspore.ops.svd
|
||||
mindspore.ops.tan
|
||||
mindspore.ops.true_divide
|
||||
|
|
|
@ -328,6 +328,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"true_divide", std::string("true_divide")}, // true_divide()
|
||||
{"triu", std::string("triu")}, // triu()
|
||||
{"subtract", std::string("subtract")}, // true_divide()
|
||||
{"sum_to_size", std::string("sum_to_size")}, // sum_to_size()
|
||||
{"exp", std::string("exp")}, // P.Exp()
|
||||
{"repeat", std::string("repeat")}, // C.repeat_elements
|
||||
{"bernoulli", prim::kPrimBernoulli}, // P.Bernoulli()
|
||||
|
|
|
@ -1205,8 +1205,8 @@ def permute(x, *dims):
|
|||
"""
|
||||
if dims is None:
|
||||
raise ValueError(f"For Tensor.permute, the dims must not be none.")
|
||||
if len(dims) == 1:
|
||||
return F.permute(x, *dims)
|
||||
if len(dims) == 1 and isinstance(dims[0], tuple):
|
||||
return F.permute(x, dims[0])
|
||||
return F.permute(x, dims)
|
||||
|
||||
|
||||
|
@ -1989,6 +1989,11 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disa
|
|||
return res.astype(dtype)
|
||||
|
||||
|
||||
def sum_to_size(x, *size):
|
||||
"""For details, please refer to :func:`mindspore.ops.sum_to_size`."""
|
||||
return F.sum_to_size(x, *size)
|
||||
|
||||
|
||||
def repeat(x, repeats, axis=None):
|
||||
"""
|
||||
Repeat elements of an array.
|
||||
|
|
|
@ -1819,8 +1819,8 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
if not dims:
|
||||
raise ValueError(f"For Tensor.permute, the dims must not be none.")
|
||||
if len(dims) == 1:
|
||||
return tensor_operator_registry.get("permute")(self, *dims)
|
||||
if len(dims) == 1 and isinstance(dims[0], tuple):
|
||||
return tensor_operator_registry.get("permute")(self, dims[0])
|
||||
return tensor_operator_registry.get("permute")(self, dims)
|
||||
|
||||
def positive(self):
|
||||
|
@ -3037,6 +3037,11 @@ class Tensor(Tensor_):
|
|||
res += initial
|
||||
return res.astype(dtype)
|
||||
|
||||
def sum_to_size(self, *size):
|
||||
"""For details, please refer to :func:`mindspore.ops.sum_to_size`."""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get("sum_to_size")(self, *size)
|
||||
|
||||
def repeat(self, repeats, axis=None):
|
||||
"""
|
||||
Repeat elements of a tensor.
|
||||
|
|
|
@ -2386,9 +2386,8 @@ class CTCLoss(LossBase):
|
|||
|
||||
def construct(self, log_probs, targets, input_lengths, target_lengths):
|
||||
if len(log_probs.shape) == 2:
|
||||
n, c = log_probs.shape
|
||||
log_probs = log_probs.reshape((n, 1, c))
|
||||
targets = targets.reshape(1, targets.shape[0])
|
||||
log_probs = log_probs.expand_dims(-2)
|
||||
targets = targets.expand_dims(0)
|
||||
if isinstance(input_lengths, int):
|
||||
input_lengths = Tensor([input_lengths], mstype.int32)
|
||||
else:
|
||||
|
|
|
@ -172,6 +172,7 @@ from .math_func import (
|
|||
real,
|
||||
sub,
|
||||
subtract,
|
||||
sum_to_size,
|
||||
sqrt,
|
||||
square,
|
||||
tensor_mul,
|
||||
|
|
|
@ -809,6 +809,48 @@ def subtract(x, other, *, alpha=1):
|
|||
return tensor_sub(x, alpha * other)
|
||||
|
||||
|
||||
def sum_to_size(x, *size):
|
||||
"""
|
||||
Sum `x` Tensor to the `size`. `size` must be expandable to the Tensor size.
|
||||
|
||||
Args:
|
||||
x (Tensor): The Tensor to be summed.
|
||||
size (Union[tuple(int), int]): The expected shape of output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, the sum result of `x` according to the `size`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `size` is not expandable to the size of `x`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.random.randn(3, 3, 3, 3, 3, 3), mindspore.float32)
|
||||
>>> output = ops.sum_to_size(x, (1, 3, 1, 3))
|
||||
>>> print(output.shape)
|
||||
(1, 3, 1, 3)
|
||||
"""
|
||||
if len(size) == 1 and isinstance(size[0], tuple):
|
||||
size = size[0]
|
||||
shape_x = x.shape
|
||||
if len(size) > x.ndim:
|
||||
raise ValueError(f"For sum_to_size, size {size} is not expandable to the tensor size {shape_x}.")
|
||||
if len(size) < x.ndim:
|
||||
pre_axis = tuple([axis for axis in range(x.ndim - len(size))])
|
||||
x = x.sum(pre_axis)
|
||||
axes = []
|
||||
for i, element in enumerate(size):
|
||||
if element != x.shape[i] and element == 1:
|
||||
axes.append(i)
|
||||
elif element != x.shape[i]:
|
||||
raise ValueError(f"For sum_to_size, size {size} is not expandable to the tensor size {shape_x}.")
|
||||
if axes:
|
||||
return x.sum(tuple(axes), keepdims=True)
|
||||
return x
|
||||
|
||||
|
||||
def true_divide(dividend, divisor):
|
||||
r"""
|
||||
Alias for Tensor.div() with :math:`rounding\_mode=None`.
|
||||
|
@ -9217,6 +9259,7 @@ __all__ = [
|
|||
'tensor_sub',
|
||||
'sub',
|
||||
'subtract',
|
||||
'sum_to_size',
|
||||
'tensor_mul',
|
||||
'mul',
|
||||
'multiply',
|
||||
|
|
|
@ -140,6 +140,7 @@ tensor_operator_registry.register('bincount', bincount)
|
|||
tensor_operator_registry.register('sqrt', sqrt)
|
||||
tensor_operator_registry.register('square', square)
|
||||
tensor_operator_registry.register('sub', sub)
|
||||
tensor_operator_registry.register('sum_to_size', sum_to_size)
|
||||
tensor_operator_registry.register('triu', Triu)
|
||||
tensor_operator_registry.register('tan', P.Tan)
|
||||
tensor_operator_registry.register('acos', acos)
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return ops.sum_to_size(x, (3, 1, 3))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_ops_sum_to_size(mode):
|
||||
"""
|
||||
Feature: ops.sum_to_size
|
||||
Description: Verify the result of sum_to_size
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[[24, 20, 39],
|
||||
[79, 67, 43],
|
||||
[62, 0, 95]],
|
||||
[[74, 5, 33],
|
||||
[0, 35, 78],
|
||||
[67, 0, 29]],
|
||||
[[45, 42, 77],
|
||||
[70, 61, 72],
|
||||
[23, 82, 47]]], ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = [[[165, 87, 177]],
|
||||
[[141, 40, 140]],
|
||||
[[138, 185, 196]]]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.sum_to_size((3, 1, 3))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_sum_to_size(mode):
|
||||
"""
|
||||
Feature: Tensor.sum_to_size
|
||||
Description: Verify the result of sum_to_size
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor([[[24, 20, 39],
|
||||
[79, 67, 43],
|
||||
[62, 0, 95]],
|
||||
[[74, 5, 33],
|
||||
[0, 35, 78],
|
||||
[67, 0, 29]],
|
||||
[[45, 42, 77],
|
||||
[70, 61, 72],
|
||||
[23, 82, 47]]], ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = [[[165, 87, 177]],
|
||||
[[141, 40, 140]],
|
||||
[[138, 185, 196]]]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue