tensor_diagflat_master
This commit is contained in:
parent
44cfcca6f4
commit
6374215d34
|
@ -432,6 +432,7 @@ Array操作
|
|||
mindspore.ops.concat
|
||||
mindspore.ops.count_nonzero
|
||||
mindspore.ops.diag
|
||||
mindspore.ops.diagflat
|
||||
mindspore.ops.diagonal
|
||||
mindspore.ops.dyn_shape
|
||||
mindspore.ops.dsplit
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.diagflat
|
||||
=========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.diagflat(offset=0)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.diagflat`。
|
|
@ -89,6 +89,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.deg2rad
|
||||
mindspore.Tensor.det
|
||||
mindspore.Tensor.diag
|
||||
mindspore.Tensor.diagflat
|
||||
mindspore.Tensor.diagonal
|
||||
mindspore.Tensor.digamma
|
||||
mindspore.Tensor.div
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
mindspore.ops.diagflat
|
||||
======================
|
||||
|
||||
.. py:function:: mindspore.ops.diagflat(x, offset=0)
|
||||
|
||||
创建一个二维Tensor,用展开后的输入作为它的对角线。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - 输入Tensor,展开后作为输出的对角线。
|
||||
- **offset** (int, 可选) - `offset` 控制选择哪条对角线。默认值:0。
|
||||
- 当 `offset` 是0时,选择的对角线是主对角线。
|
||||
- 当 `offset` 大于0时,选择的对角线在主对角线上。
|
||||
- 当 `offset` 小于0时,选择的对角线在主对角线下。
|
||||
|
||||
返回:
|
||||
二维Tensor。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不是Tensor。
|
|
@ -9,7 +9,7 @@ mindspore.ops.pixel_shuffle
|
|||
|
||||
参数:
|
||||
- **x** (Tensor) - Tensor,shape为 :math:`(*, C \times r^2, H, W)` 。 `x` 的维度需要大于2,并且倒数第三维length可以被 `upscale_factor` 的平方整除。
|
||||
- **upscale_factor** (int) - 增加空间分辨率的因子,是正整数。。
|
||||
- **upscale_factor** (int) - 增加空间分辨率的因子,是正整数。
|
||||
|
||||
返回:
|
||||
- **output** (Tensor) - Tensor,shape为 :math:`(*, C, H \times r, W \times r)` 。
|
||||
|
|
|
@ -95,6 +95,7 @@
|
|||
mindspore.Tensor.deg2rad
|
||||
mindspore.Tensor.det
|
||||
mindspore.Tensor.diag
|
||||
mindspore.Tensor.diagflat
|
||||
mindspore.Tensor.diagonal
|
||||
mindspore.Tensor.digamma
|
||||
mindspore.Tensor.div
|
||||
|
|
|
@ -432,6 +432,7 @@ Array Operation
|
|||
mindspore.ops.concat
|
||||
mindspore.ops.count_nonzero
|
||||
mindspore.ops.diag
|
||||
mindspore.ops.diagflat
|
||||
mindspore.ops.diagonal
|
||||
mindspore.ops.dsplit
|
||||
mindspore.ops.dyn_shape
|
||||
|
|
|
@ -351,6 +351,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"unique_consecutive", std::string("unique_consecutive")}, // UniqueConsecutive()
|
||||
{"unique_with_pad", std::string("unique_with_pad")}, // P.UniqueWithPad()
|
||||
{"diag", std::string("diag")}, // P.Diag()
|
||||
{"diagflat", std::string("diagflat")}, // diagflat()
|
||||
{"digamma", std::string("digamma")}, // digamma()
|
||||
{"lgamma", std::string("lgamma")}, // lgamma()
|
||||
{"adaptive_max_pool2d", std::string("adaptive_max_pool2d")}, // P.AdaptiveMaxPool2D
|
||||
|
|
|
@ -2997,6 +2997,13 @@ def diag(x):
|
|||
return F.diag(x)
|
||||
|
||||
|
||||
def diagflat(x, offset=0):
|
||||
"""
|
||||
Creates a two-dimensional Tensor with the flattened input as a diagonal.
|
||||
"""
|
||||
return F.diagflat(x, offset)
|
||||
|
||||
|
||||
def masked_select(x, mask):
|
||||
"""
|
||||
Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask.
|
||||
|
|
|
@ -3349,6 +3349,13 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('diag')()(self)
|
||||
|
||||
def diagflat(self, offset=0):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.diagflat`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('diagflat')(self, offset)
|
||||
|
||||
def xdivy(self, y):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.xdivy`.
|
||||
|
|
|
@ -111,6 +111,7 @@ from .array_func import (
|
|||
matrix_diag_part,
|
||||
matrix_set_diag,
|
||||
diag,
|
||||
diagflat,
|
||||
masked_select,
|
||||
where,
|
||||
meshgrid,
|
||||
|
|
|
@ -3840,7 +3840,7 @@ def matrix_diag_part(x, k=0, padding_value=0, align="RIGHT_LEFT"):
|
|||
return matrix_diag_part_v3(x, k, padding_value)
|
||||
|
||||
|
||||
def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"):
|
||||
def matrix_set_diag(x, diagonal, k=0, align="RIGHT_LEFT"): # pylint: disable=redefined-outer-name
|
||||
r"""
|
||||
Returns a batched matrix tensor with new batched diagonal values.
|
||||
Given x and diagonal, this operation returns a tensor with the same shape and values as x, except for the specified
|
||||
|
@ -4651,6 +4651,57 @@ def diag(input_x):
|
|||
return diag_(input_x)
|
||||
|
||||
|
||||
def diagflat(x, offset=0):
|
||||
r"""
|
||||
Creates a two-dimensional Tensor with the flattened input as a diagonal.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input Tensor, which is flattened and set as the diagonal of the output.
|
||||
offset (int, optional): `offset` controls which diagonal to consider. Default: 0.
|
||||
|
||||
- When `offset` is zero, the diagonal chosen is the main diagonal.
|
||||
- When `offset` is greater than zero, the diagonal chosen is above the main diagonal.
|
||||
- When `offset` is less than zero, the diagonal chosen is below the main diagonal.
|
||||
|
||||
Returns:
|
||||
The 2-D Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a tensor.
|
||||
TypeError: If `offset` is not an integer.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([1, 2], mindspore.float32)
|
||||
>>> output = ops.diagflat(x, 1)
|
||||
>>> print(output)
|
||||
[[0. 1. 0.]
|
||||
[0. 0. 2.]
|
||||
[0. 0. 0.]]
|
||||
"""
|
||||
if not isinstance(x, (Tensor, Tensor_)):
|
||||
raise TypeError(f"For diagflat, the input x must be tensor, but got {type(x)}")
|
||||
if not isinstance(offset, int):
|
||||
raise TypeError(f"For diagflat, the offset must be int, but got {type(offset)}")
|
||||
offset_abs = abs(offset)
|
||||
if x.size == 0:
|
||||
return zeros((offset_abs, offset_abs), x.dtype)
|
||||
x = x.ravel()
|
||||
res = diag(x)
|
||||
if offset != 0:
|
||||
pad_y = zeros((x.size + offset_abs, offset_abs), x.dtype)
|
||||
pad_x = zeros((offset_abs, x.size), x.dtype)
|
||||
if offset < 0:
|
||||
res = cat((pad_x, res), axis=0)
|
||||
res = cat((res, pad_y), axis=1)
|
||||
else:
|
||||
res = cat((res, pad_x), axis=0)
|
||||
res = cat((pad_y, res), axis=1)
|
||||
return res
|
||||
|
||||
|
||||
def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
|
||||
"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
|
@ -5965,6 +6016,7 @@ __all__ = [
|
|||
'matrix_diag_part',
|
||||
'matrix_set_diag',
|
||||
'diag',
|
||||
'diagflat',
|
||||
'meshgrid',
|
||||
'affine_grid',
|
||||
'meshgrid',
|
||||
|
|
|
@ -4218,7 +4218,7 @@ def hardtanh(x, min_val=-1.0, max_val=1.0):
|
|||
|
||||
def huber_loss(x, target, reduction='mean', delta=1.0):
|
||||
r"""
|
||||
huber_loss calculate the error between the predicted value and the target value.
|
||||
huber_loss calculates the error between the predicted value and the target value.
|
||||
It has the advantages of both l1_loss and mse_loss.
|
||||
|
||||
Assuming that the :math:`x` and :math:`y` are 1-D Tensor, length :math:`N`, the reduction parameter is set to "none"
|
||||
|
|
|
@ -234,6 +234,7 @@ tensor_operator_registry.register('hypot', hypot)
|
|||
tensor_operator_registry.register('soft_shrink', P.SoftShrink)
|
||||
tensor_operator_registry.register('svd', linalg_ops.Svd)
|
||||
tensor_operator_registry.register('diag', P.Diag)
|
||||
tensor_operator_registry.register('diagflat', diagflat)
|
||||
tensor_operator_registry.register('unique_consecutive', UniqueConsecutive)
|
||||
tensor_operator_registry.register('unique_with_pad', P.UniqueWithPad)
|
||||
tensor_operator_registry.register('inplace_update', P.InplaceUpdate)
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
out1 = ops.diagflat(x, -1)
|
||||
out2 = ops.diagflat(x, 1)
|
||||
return out1, out2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_ops_diagflat(mode):
|
||||
"""
|
||||
Feature: ops.diagflat
|
||||
Description: Verify the result of diagflat
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = ms.Tensor([-0.5, 0.5, 3], ms.float32)
|
||||
net = Net()
|
||||
output1, output2 = net(x)
|
||||
expect_output1 = [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
|
||||
[-5.00000000e-01, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
|
||||
[0.00000000e+00, 5.00000000e-01, 0.00000000e+00, 0.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 3.00000000e+00, 0.00000000e+00]]
|
||||
expect_output2 = [[0.00000000e+00, -5.00000000e-01, 0.00000000e+00, 0.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 5.00000000e-01, 0.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]
|
||||
assert np.allclose(output1.asnumpy(), expect_output1)
|
||||
assert np.allclose(output2.asnumpy(), expect_output2)
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.diagflat(1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_diagflat(mode):
|
||||
"""
|
||||
Feature: tensor.diagflat
|
||||
Description: Verify the result of diagflat
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = ms.Tensor([-0.5, 0.5, 3], ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = [[0.00000000e+00, -5.00000000e-01, 0.00000000e+00, 0.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 5.00000000e-01, 0.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 3.00000000e+00],
|
||||
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
Loading…
Reference in New Issue