!46058 support rsqrt reciprocal real

Merge pull request !46058 from 冯一航/support_rsqrt_reciprocal_real
This commit is contained in:
i-robot 2022-11-29 09:00:43 +00:00 committed by Gitee
commit f7b2d406d1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
26 changed files with 633 additions and 4 deletions

View File

@ -230,9 +230,12 @@ mindspore.ops
mindspore.ops.positive
mindspore.ops.pow
mindspore.ops.rad2deg
mindspore.ops.real
mindspore.ops.reciprocal
mindspore.ops.remainder
mindspore.ops.roll
mindspore.ops.round
mindspore.ops.rsqrt
mindspore.ops.sin
mindspore.ops.sinh
mindspore.ops.sqrt

View File

@ -0,0 +1,6 @@
mindspore.Tensor.real
======================
.. py:function:: mindspore.Tensor.real(x)
详情请参考 :func:`mindspore.ops.real`

View File

@ -0,0 +1,6 @@
mindspore.Tensor.reciprocal
============================
.. py:function:: mindspore.Tensor.reciprocal(x)
详情请参考 :func:`mindspore.Tensor.reciprocal`

View File

@ -0,0 +1,7 @@
mindspore.Tensor.rsqrt
=======================
.. py:function:: mindspore.Tensor.rsqrt(x)
详情请参考 :func:`mindspore.ops.rsqrt`

View File

@ -195,6 +195,8 @@ mindspore.Tensor
mindspore.Tensor.rad2deg
mindspore.Tensor.random_categorical
mindspore.Tensor.ravel
mindspore.Tensor.real
mindspore.Tensor.reciprocal
mindspore.Tensor.remainder
mindspore.Tensor.renorm
mindspore.Tensor.repeat
@ -207,6 +209,7 @@ mindspore.Tensor
mindspore.Tensor.round
mindspore.Tensor.roll
mindspore.Tensor.rot90
mindspore.Tensor.rsqrt
mindspore.Tensor.scatter_add
mindspore.Tensor.scatter_div
mindspore.Tensor.scatter_max

View File

@ -0,0 +1,15 @@
mindspore.ops.real
===================
.. py:function:: mindspore.ops.real(x)
返回输入Tensor的实数部分。如果输入是实数则返回值与输入值相同。
参数:
- **x** (Tensor) - 要计算的输入Tensor。
返回:
Tensorshape与输入 `x` 相同。
异常:
- **TypeError** - 如果 `x` 不是Tensor。

View File

@ -0,0 +1,18 @@
mindspore.ops.reciprocal
=========================
.. py:function:: mindspore.ops.reciprocal(x)
返回输入Tensor的倒数。
.. math::
out_{i} = \frac{1}{x_{i}}
参数:
- **x** (Tensor) - 输入Tensor。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。
返回:
Tensorshape与 `x` 相同。
异常:
- **TypeError** - 如果 `x` 不是Tensor。

View File

@ -0,0 +1,19 @@
mindspore.ops.rsqrt
====================
.. py:function:: mindspore.ops.rsqrt(x)
逐元素计算输入Tensor元素的平方根倒数。
.. math::
out_{i} = \frac{1}{\sqrt{x_{i}}}
参数:
- **x** (Tensor) - rsqrt的输入Tensor其rank需要在[0, 7]范围内且每个元素都为非负若某个元素为负计算结果为nan。
返回:
Tensor具有与 `x` 相同的shape。
异常:
- **TypeError** - 如果 `x` 不是Tensor。

View File

@ -201,6 +201,8 @@
mindspore.Tensor.rad2deg
mindspore.Tensor.random_categorical
mindspore.Tensor.ravel
mindspore.Tensor.real
mindspore.Tensor.reciprocal
mindspore.Tensor.remainder
mindspore.Tensor.renorm
mindspore.Tensor.repeat
@ -213,6 +215,7 @@
mindspore.Tensor.round
mindspore.Tensor.roll
mindspore.Tensor.rot90
mindspore.Tensor.rsqrt
mindspore.Tensor.scatter_add
mindspore.Tensor.scatter_div
mindspore.Tensor.scatter_max

View File

@ -231,9 +231,12 @@ Element-by-Element Operations
mindspore.ops.positive
mindspore.ops.pow
mindspore.ops.rad2deg
mindspore.ops.real
mindspore.ops.reciprocal
mindspore.ops.remainder
mindspore.ops.roll
mindspore.ops.round
mindspore.ops.rsqrt
mindspore.ops.sin
mindspore.ops.sinh
mindspore.ops.sqrt

View File

@ -308,6 +308,9 @@ BuiltInTypeMap &GetMethodMap() {
{"unsorted_segment_max", std::string("unsorted_segment_max")}, // P.UnsortedSegmentMax()
{"unsorted_segment_prod", std::string("unsorted_segment_prod")}, // P.UnsortedSegmentProd()
{"renorm", std::string("renorm")}, // renorm()
{"real", std::string("real")}, // real()
{"reciprocal", std::string("reciprocal")}, // reciprocal()
{"rsqrt", std::string("rsqrt")}, // rsqrt()
{"trace", std::string("trace")}, // P.Eye()
{"var", std::string("var")}, // P.ReduceSum
{"std", std::string("std")}, // P.ReduceSum

View File

@ -2703,6 +2703,27 @@ def exp(x):
return F.exp(x)
def real(x):
r"""
For details, please refer to :func:`mindspore.ops.real`.
"""
return F.real(x)
def rsqrt(x):
r"""
For details, please refer to :func:`mindspore.ops.rsqrt`.
"""
return F.rsqrt(x)
def reciprocal(x):
r"""
For details, please refer to :func:`mindspore.ops.reciprocal`.
"""
return F.reciprocal(x)
def sqrt(x):
"""Returns sqrt of a tensor element-wise."""
return F.sqrt(x)

View File

@ -1214,6 +1214,27 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('exp')()(self)
def real(self):
r"""
For details, please refer to :func:`mindspore.ops.real`.
"""
self._init_check()
return tensor_operator_registry.get('real')(self)
def rsqrt(self):
r"""
For details, please refer to :func:`mindspore.ops.rsqrt`.
"""
self._init_check()
return tensor_operator_registry.get('rsqrt')(self)
def reciprocal(self):
r"""
For details, please refer to :func:`mindspore.ops.reciprocal`.
"""
self._init_check()
return tensor_operator_registry.get('reciprocal')(self)
def sqrt(self):
"""
For details, please refer to :func:`mindspore.ops.sqrt`.
@ -3372,7 +3393,6 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('heaviside')(self, values)
def hypot(self, other):
r"""
For details, please refer to :func:`mindspore.ops.hypot`.

View File

@ -158,6 +158,9 @@ from .math_func import (
tensor_ge,
ge,
tensor_sub,
rsqrt,
reciprocal,
real,
sub,
subtract,
sqrt,

View File

@ -4927,9 +4927,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
(16, 16, 8, 8)
"""
kernel_size = _check_fold_param(kernel_size, "kernel_size")
dilation = _check_fold_param(kernel_size, "dilation")
padding = _check_fold_param(kernel_size, "padding")
stride = _check_fold_param(kernel_size, "stride")
dilation = _check_fold_param(dilation, "dilation")
padding = _check_fold_param(padding, "padding")
stride = _check_fold_param(stride, "stride")
fold_op = _get_cache_prim(Col2Im)(kernel_size, dilation, padding, stride)
return fold_op(input, output_size)

View File

@ -3813,6 +3813,100 @@ def std(input_x, axis=(), unbiased=True, keep_dims=False):
return output
def real(x):
r"""
Returns a Tensor that is the real part of the input.
If input is real, it is returned unchanged.
Args:
x (Tensor): The input tensor to compute to.
Returns:
Tensor, the shape is the same as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64)
>>> output = ops.real(x)
>>> print(output)
1.3
"""
return _get_cache_prim(ops.Real)()(x)
def reciprocal(x):
r"""
Returns reciprocal of a tensor element-wise.
.. math::
out_{i} = \frac{1}{x_{i}}
Args:
x (Tensor): The input tensor.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
Returns:
Tensor, has the same shape as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> import numpy as np
>>> x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
>>> output = ops.reciprocal(x)
>>> print(output)
[1. 0.5 0.25]
"""
return _get_cache_prim(ops.Reciprocal)()(x)
def rsqrt(x):
r"""
Computes reciprocal of square root of input tensor element-wise.
.. math::
out_{i} = \frac{1}{\sqrt{x_{i}}}
Args:
x (Tensor): The input of rsqrt. Its rank must be in [0, 7] inclusive and
each element must be a non-negative number, if an element is negative, the calculation result is nan.
Returns:
Tensor, has the same shape and dtype as the `x`.
Raises:
TypeError: If `x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.ops as ops
>>> x = ms.Tensor([-0.0370, 0.2970, 1.5420, -0.9105])
>>> output = ops.rsqrt(x)
>>> print(output)
[ nan 1.8349396 0.80530024 nan]
"""
return _get_cache_prim(ops.Rsqrt)()(x)
def sqrt(x):
"""
Returns sqrt of a tensor element-wise.
@ -7582,6 +7676,9 @@ __all__ = [
'logit',
'logsumexp',
'ldexp',
'rsqrt',
'reciprocal',
'real',
'sqrt',
'square',
'sin',

View File

@ -132,6 +132,9 @@ tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('atan2', atan2)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('baddbmm', baddbmm)
tensor_operator_registry.register('real', real)
tensor_operator_registry.register('reciprocal', reciprocal)
tensor_operator_registry.register('rsqrt', rsqrt)
tensor_operator_registry.register('sqrt', sqrt)
tensor_operator_registry.register('square', square)
tensor_operator_registry.register('sub', sub)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.real(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_real_normal(mode):
"""
Feature: real
Description: Verify the result of real
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64)
out = net(x)
expect_out = np.array(1.3)
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.reciprocal(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_reciprocal_normal(mode):
"""
Feature: reciprocal
Description: Verify the result of reciprocal
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
out = net(x)
expect_out = np.array([1., 0.5, 0.25])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,49 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, x):
output = ops.rsqrt(x)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_rsqrt_normal(mode):
"""
Feature: rsqrt
Description: Verify the result of rsqrt
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([0.0370, 0.2970, 1.5420, 0.9105])
out = net(x)
expect_out = np.array([5.1987524, 1.8349396, 0.80530024, 1.047997])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.real()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_real_normal(mode):
"""
Feature: real
Description: Verify the result of real
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.asarray(np.complex(1.3+0.4j)), ms.complex64)
out = net(x)
expect_out = np.array(1.3)
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,48 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.reciprocal()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_reciprocal_normal(mode):
"""
Feature: reciprocal
Description: Verify the result of reciprocal
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
out = net(x)
expect_out = np.array([1., 0.5, 0.25])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,48 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Net(nn.Cell):
def construct(self, x):
output = x.rsqrt()
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_rsqrt_normal(mode):
"""
Feature: rsqrt
Description: Verify the result of rsqrt
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = ms.Tensor([0.0370, 0.2970, 1.5420, 0.9105])
out = net(x)
expect_out = np.array([5.1987524, 1.8349396, 0.80530024, 1.047997])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Net(nn.Cell):
def construct(self, x):
output = ops.real(x)
return output
def test_real_normal():
"""
Feature: Test real
Description: Test the functionality of real
Expectation: Success
"""
net = Net()
x = ms.Tensor(np.asarray(np.complex(1.3 + 0.4j)), ms.complex64)
_cell_graph_executor.compile(net, x)

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Net(nn.Cell):
def construct(self, x):
output = ops.reciprocal(x)
return output
def test_reciprocal_normal():
"""
Feature: Test reciprocal
Description: Test the functionality of reciprocal
Expectation: Success
"""
net = Net()
x = ms.Tensor(np.array([1.0, 2.0, 4.0]), ms.float32)
_cell_graph_executor.compile(net, x)

View File

@ -0,0 +1,36 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.api import _cell_graph_executor
class Net(nn.Cell):
def construct(self, x):
output = ops.rsqrt(x)
return output
def test_rsqrt_normal():
"""
Feature: Test rsqrt
Description: Test the functionality of rsqrt
Expectation: Success
"""
net = Net()
x = ms.Tensor([-0.0370, 0.2970, 1.5420, -0.9105])
_cell_graph_executor.compile(net, x)