tensor_diagflat_master

This commit is contained in:
yide12 2023-01-05 20:21:23 +08:00
parent 44cfcca6f4
commit 6374215d34
16 changed files with 207 additions and 3 deletions

View File

@ -432,6 +432,7 @@ Array Operation
mindspore.ops.concat
mindspore.ops.count_nonzero
mindspore.ops.diag
mindspore.ops.diagflat
mindspore.ops.diagonal
mindspore.ops.dyn_shape
mindspore.ops.dsplit

View File

@ -0,0 +1,6 @@
mindspore.Tensor.diagflat
=========================
.. py:method:: mindspore.Tensor.diagflat(offset=0)
For details, refer to :func:`mindspore.ops.diagflat`.
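A brief usage sketch of the method form (the import and values below are illustrative, not part of the commit):

import mindspore as ms

x = ms.Tensor([1., 2.], ms.float32)
print(x.diagflat())   # 2x2 Tensor with 1 and 2 on the main diagonal
print(x.diagflat(1))  # 3x3 Tensor with 1 and 2 one step above the main diagonal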

View File

@ -89,6 +89,7 @@ mindspore.Tensor
mindspore.Tensor.deg2rad
mindspore.Tensor.det
mindspore.Tensor.diag
mindspore.Tensor.diagflat
mindspore.Tensor.diagonal
mindspore.Tensor.digamma
mindspore.Tensor.div

View File

@ -0,0 +1,19 @@
mindspore.ops.diagflat
======================
.. py:function:: mindspore.ops.diagflat(x, offset=0)
Creates a two-dimensional Tensor with the flattened input as its diagonal.
Parameters:
- **x** (Tensor) - The input Tensor, which is flattened and set as the diagonal of the output.
- **offset** (int, optional) - `offset` controls which diagonal to consider. Default: 0.
- When `offset` is 0, the diagonal chosen is the main diagonal.
- When `offset` is greater than 0, the diagonal chosen is above the main diagonal.
- When `offset` is less than 0, the diagonal chosen is below the main diagonal.
Returns:
The 2-D Tensor.
Raises:
- **TypeError** - `x` is not a Tensor.
- **TypeError** - `offset` is not an integer.
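A short usage sketch of the functional form described above (the imports and values are illustrative):

import mindspore as ms
from mindspore import ops

x = ms.Tensor([1., 2.], ms.float32)
print(ops.diagflat(x))      # [[1. 0.] [0. 2.]] -- offset 0, main diagonal
print(ops.diagflat(x, 1))   # 3x3, with 1 and 2 placed one step above the main diagonal
print(ops.diagflat(x, -1))  # 3x3, with 1 and 2 placed one step below the main diagonal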

View File

@ -9,7 +9,7 @@ mindspore.ops.pixel_shuffle
Parameters:
- **x** (Tensor) - Tensor with shape :math:`(*, C \times r^2, H, W)`. The dimension of `x` must be greater than 2, and the length of the third-to-last dimension must be divisible by the square of `upscale_factor`.
- **upscale_factor** (int) - The factor by which to increase the spatial resolution; must be a positive integer.
- **upscale_factor** (int) - The factor by which to increase the spatial resolution; must be a positive integer.
Returns:
- **output** (Tensor) - Tensor with shape :math:`(*, C, H \times r, W \times r)`.
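A small shape sketch for the pixel_shuffle signature above (values are illustrative; `r` stands for `upscale_factor`):

import numpy as np
import mindspore as ms
from mindspore import ops

r = 2
x = ms.Tensor(np.arange(36, dtype=np.float32).reshape(1, 4, 3, 3))  # (*, C*r^2, H, W) with C=1
out = ops.pixel_shuffle(x, r)
print(out.shape)  # (1, 1, 6, 6), i.e. (*, C, H*r, W*r)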

View File

@ -95,6 +95,7 @@
mindspore.Tensor.deg2rad
mindspore.Tensor.det
mindspore.Tensor.diag
mindspore.Tensor.diagflat
mindspore.Tensor.diagonal
mindspore.Tensor.digamma
mindspore.Tensor.div

View File

@ -432,6 +432,7 @@ Array Operation
mindspore.ops.concat
mindspore.ops.count_nonzero
mindspore.ops.diag
mindspore.ops.diagflat
mindspore.ops.diagonal
mindspore.ops.dsplit
mindspore.ops.dyn_shape

View File

@ -351,6 +351,7 @@ BuiltInTypeMap &GetMethodMap() {
{"unique_consecutive", std::string("unique_consecutive")}, // UniqueConsecutive()
{"unique_with_pad", std::string("unique_with_pad")}, // P.UniqueWithPad()
{"diag", std::string("diag")}, // P.Diag()
{"diagflat", std::string("diagflat")}, // diagflat()
{"digamma", std::string("digamma")}, // digamma()
{"lgamma", std::string("lgamma")}, // lgamma()
{"adaptive_max_pool2d", std::string("adaptive_max_pool2d")}, // P.AdaptiveMaxPool2D

View File

@ -2997,6 +2997,13 @@ def diag(x):
return F.diag(x)
def diagflat(x, offset=0):
"""
Creates a two-dimensional Tensor with the flattened input as a diagonal.
"""
return F.diagflat(x, offset)
def masked_select(x, mask):
"""
Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask.

View File

@ -3349,6 +3349,13 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('diag')()(self)
def diagflat(self, offset=0):
r"""
For details, please refer to :func:`mindspore.ops.diagflat`.
"""
self._init_check()
return tensor_operator_registry.get('diagflat')(self, offset)
def xdivy(self, y):
r"""
For details, please refer to :func:`mindspore.ops.xdivy`.

View File

@ -111,6 +111,7 @@ from .array_func import (
matrix_diag_part,
matrix_set_diag,
diag,
diagflat,
masked_select,
where,
meshgrid,

View File

@ -3840,7 +3840,7 @@ def matrix_diag_part(x, k=0, padding_value=0, align="RIGHT_LEFT"):
return matrix_diag_part_v3(x, k, padding_value)
def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"):
def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"): # pylint: disable=redefined-outer-name
r"""
Returns a batched matrix tensor with new batched diagonal values.
Given x and diagonal, this operation returns a tensor with the same shape and values as x, except for the specified
@ -4651,6 +4651,57 @@ def diag(input_x):
return diag_(input_x)
def diagflat(x, offset=0):
r"""
Creates a two-dimensional Tensor with the flattened input as a diagonal.
Args:
x (Tensor): Input Tensor, which is flattened and set as the diagonal of the output.
offset (int, optional): `offset` controls which diagonal to consider. Default: 0.
- When `offset` is zero, the diagonal chosen is the main diagonal.
- When `offset` is greater than zero, the diagonal chosen is above the main diagonal.
- When `offset` is less than zero, the diagonal chosen is below the main diagonal.
Returns:
The 2-D Tensor.
Raises:
TypeError: If `x` is not a tensor.
TypeError: If `offset` is not an integer.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([1, 2], mindspore.float32)
>>> output = ops.diagflat(x, 1)
>>> print(output)
[[0. 1. 0.]
[0. 0. 2.]
[0. 0. 0.]]
"""
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError(f"For diagflat, the input x must be tensor, but got {type(x)}")
if not isinstance(offset, int):
raise TypeError(f"For diagflat, the offset must be int, but got {type(offset)}")
offset_abs = abs(offset)
if x.size == 0:
return zeros((offset_abs, offset_abs), x.dtype)
x = x.ravel()
res = diag(x)
if offset != 0:
pad_y = zeros((x.size + offset_abs, offset_abs), x.dtype)
pad_x = zeros((offset_abs, x.size), x.dtype)
if offset < 0:
res = cat((pad_x, res), axis=0)
res = cat((res, pad_y), axis=1)
else:
res = cat((res, pad_x), axis=0)
res = cat((pad_y, res), axis=1)
return res
def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
"""
Combines an array of sliding local blocks into a large containing tensor.
@ -5965,6 +6016,7 @@ __all__ = [
'matrix_diag_part',
'matrix_set_diag',
'diag',
'diagflat',
'meshgrid',
'affine_grid',
'meshgrid',
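The pad-and-concatenate scheme used in the diagflat implementation above can be cross-checked against NumPy; a minimal reference sketch (assuming NumPy only, not part of the commit):

import numpy as np

def diagflat_ref(x, offset=0):
    # Flatten the input, build the square diagonal matrix, then pad zero rows
    # and columns on the appropriate sides to shift the diagonal by `offset`.
    x = np.ravel(np.asarray(x))
    n, k = x.size, abs(offset)
    res = np.diag(x)
    pad_rows = np.zeros((k, n), x.dtype)
    pad_cols = np.zeros((n + k, k), x.dtype)
    if offset < 0:
        res = np.concatenate((pad_rows, res), axis=0)  # push the diagonal below the main one
        res = np.concatenate((res, pad_cols), axis=1)
    elif offset > 0:
        res = np.concatenate((res, pad_rows), axis=0)
        res = np.concatenate((pad_cols, res), axis=1)  # push the diagonal above the main one
    return res

print(diagflat_ref([1, 2], 1))  # matches the Examples block: [[0 1 0] [0 0 2] [0 0 0]]
print(np.diagflat([1, 2], 1))   # NumPy's built-in produces the same matrix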

View File

@ -4218,7 +4218,7 @@ def hardtanh(x, min_val=-1.0, max_val=1.0):
def huber_loss(x, target, reduction='mean', delta=1.0):
r"""
huber_loss calculate the error between the predicted value and the target value.
huber_loss calculates the error between the predicted value and the target value.
It has the advantages of both l1_loss and mse_loss.
Assuming that the :math:`x` and :math:`y` are 1-D Tensor, length :math:`N`, the reduction parameter is set to "none"
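For reference, an element-wise sketch of the standard Huber loss that this docstring describes (NumPy, reduction='none'; illustrative, not the library implementation):

import numpy as np

def huber_ref(x, target, delta=1.0):
    # 0.5 * (x - y)^2 where |x - y| < delta, otherwise delta * (|x - y| - 0.5 * delta)
    diff = np.abs(x - target)
    return np.where(diff < delta, 0.5 * diff ** 2, delta * (diff - 0.5 * delta))

print(huber_ref(np.array([1.0, 3.0]), np.array([1.2, 0.0])))  # [0.02 2.5]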

View File

@ -234,6 +234,7 @@ tensor_operator_registry.register('hypot', hypot)
tensor_operator_registry.register('soft_shrink', P.SoftShrink)
tensor_operator_registry.register('svd', linalg_ops.Svd)
tensor_operator_registry.register('diag', P.Diag)
tensor_operator_registry.register('diagflat', diagflat)
tensor_operator_registry.register('unique_consecutive', UniqueConsecutive)
tensor_operator_registry.register('unique_with_pad', P.UniqueWithPad)
tensor_operator_registry.register('inplace_update', P.InplaceUpdate)

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import ops
class Net(nn.Cell):
def construct(self, x):
out1 = ops.diagflat(x, -1)
out2 = ops.diagflat(x, 1)
return out1, out2
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_ops_diagflat(mode):
"""
Feature: ops.diagflat
Description: Verify the result of diagflat
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor([-0.5, 0.5, 3], ms.float32)
net = Net()
output1, output2 = net(x)
expect_output1 = [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[-5.00000000e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[0.00000000e+00, 5.00000000e-01, 0.00000000e+00, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 3.00000000e+00, 0.00000000e+00]]
expect_output2 = [[0.00000000e+00, -5.00000000e-01, 0.00000000e+00, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 5.00000000e-01, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]
assert np.allclose(output1.asnumpy(), expect_output1)
assert np.allclose(output2.asnumpy(), expect_output2)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
return x.diagflat(1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_diagflat(mode):
"""
Feature: tensor.diagflat
Description: Verify the result of diagflat
Expectation: success
"""
ms.set_context(mode=mode)
x = ms.Tensor([-0.5, 0.5, 3], ms.float32)
net = Net()
output = net(x)
expect_output = [[0.00000000e+00, -5.00000000e-01, 0.00000000e+00, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 5.00000000e-01, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]
assert np.allclose(output.asnumpy(), expect_output)