forked from mindspore-Ecosystem/mindspore
!46567 [api] add tensor.sin & tensor.sinc
Merge pull request !46567 from DavidFFFan/api
This commit is contained in:
commit
77677f153f
|
@ -257,6 +257,7 @@ mindspore.ops
|
|||
mindspore.ops.round
|
||||
mindspore.ops.rsqrt
|
||||
mindspore.ops.sin
|
||||
mindspore.ops.sinc
|
||||
mindspore.ops.sinh
|
||||
mindspore.ops.sqrt
|
||||
mindspore.ops.square
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.cos
|
||||
====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.cos()
|
||||
|
||||
详情请参考 :func:`mindspore.ops.cos`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.sin
|
||||
====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.sin()
|
||||
|
||||
详情请参考 :func:`mindspore.ops.sin`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.sinc
|
||||
=====================
|
||||
|
||||
.. py:method:: mindspore.Tensor.sinc()
|
||||
|
||||
详情请参考 :func:`mindspore.ops.sinc`。
|
|
@ -77,6 +77,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.conj
|
||||
mindspore.Tensor.copy
|
||||
mindspore.Tensor.copysign
|
||||
mindspore.Tensor.cos
|
||||
mindspore.Tensor.cosh
|
||||
mindspore.Tensor.cross
|
||||
mindspore.Tensor.cummax
|
||||
|
@ -233,6 +234,8 @@ mindspore.Tensor
|
|||
mindspore.Tensor.shape
|
||||
mindspore.Tensor.short
|
||||
mindspore.Tensor.sigmoid
|
||||
mindspore.Tensor.sin
|
||||
mindspore.Tensor.sinc
|
||||
mindspore.Tensor.size
|
||||
mindspore.Tensor.soft_shrink
|
||||
mindspore.Tensor.split
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.ops.sinc
|
||||
==================
|
||||
|
||||
.. py:function:: mindspore.ops.sinc(x)
|
||||
|
||||
按照以下公式逐元素计算输入Tensor的数学正弦函数。
|
||||
|
||||
.. math::
|
||||
out_i = \begin{cases} \frac{sin(\pi x_i)}{x_i} & x_i\neq 0\\
|
||||
1 & x_i=0 \end{cases}
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - `x` 的shape为 :math:`(x_1, x_2, ..., x_R)`。
|
||||
|
||||
返回:
|
||||
Tensor,shape与 `x` 相同。
|
||||
当输入类型在[uint8, uint8, uint16, int16, uint32, int32, uint64, int64, bool]时,返回值类型为float32。
|
||||
否则,返回值类型与输入类型相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `x` 不是Tensor。
|
|
@ -83,6 +83,7 @@
|
|||
mindspore.Tensor.conj
|
||||
mindspore.Tensor.copy
|
||||
mindspore.Tensor.copysign
|
||||
mindspore.Tensor.cos
|
||||
mindspore.Tensor.cosh
|
||||
mindspore.Tensor.cross
|
||||
mindspore.Tensor.cummax
|
||||
|
@ -239,6 +240,8 @@
|
|||
mindspore.Tensor.shape
|
||||
mindspore.Tensor.short
|
||||
mindspore.Tensor.sigmoid
|
||||
mindspore.Tensor.sin
|
||||
mindspore.Tensor.sinc
|
||||
mindspore.Tensor.size
|
||||
mindspore.Tensor.soft_shrink
|
||||
mindspore.Tensor.split
|
||||
|
|
|
@ -257,6 +257,7 @@ Element-by-Element Operations
|
|||
mindspore.ops.round
|
||||
mindspore.ops.rsqrt
|
||||
mindspore.ops.sin
|
||||
mindspore.ops.sinc
|
||||
mindspore.ops.sinh
|
||||
mindspore.ops.sqrt
|
||||
mindspore.ops.square
|
||||
|
|
|
@ -376,6 +376,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"addmv", std::string("addmv")}, // addmv()
|
||||
{"adjoint", std::string("adjoint")}, // adjoint()
|
||||
{"arccosh", std::string("acosh")}, // arccosh()
|
||||
{"sin", std::string("sin")}, // sin()
|
||||
{"sinc", std::string("sinc")}, // sinc()
|
||||
{"arcsin", std::string("asin")}, // arcsin()
|
||||
{"arctan", std::string("atan")}, // arctan()
|
||||
{"arctan2", std::string("atan2")}, // arctan2()
|
||||
|
|
|
@ -3587,6 +3587,20 @@ def isfinite(x):
|
|||
return F.isfinite(x)
|
||||
|
||||
|
||||
def sin(x):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.sin`.
|
||||
"""
|
||||
return F.sin(x)
|
||||
|
||||
|
||||
def sinc(x):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.sinc`.
|
||||
"""
|
||||
return F.sinc(x)
|
||||
|
||||
|
||||
def cos(x):
|
||||
r"""
|
||||
Computes cosine of input element-wise.
|
||||
|
|
|
@ -1315,27 +1315,7 @@ class Tensor(Tensor_):
|
|||
|
||||
def cos(self):
|
||||
r"""
|
||||
Computes cosine of input element-wise.
|
||||
|
||||
.. math::
|
||||
out_i = cos(x_i)
|
||||
|
||||
.. warning::
|
||||
Currently support Float16, Float32 data type. If use Float64, there may
|
||||
be a problem of missing precision.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same shape as `x`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> a = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
|
||||
>>> output = a.cos()
|
||||
>>> print(output)
|
||||
[0.971338 0.6748758 0.95233357 0.9959527]
|
||||
For details, please refer to :func:`mindspore.ops.cos`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('cos')(self)
|
||||
|
@ -4559,6 +4539,20 @@ class Tensor(Tensor_):
|
|||
_dtype = self.dtype if dtype is None else dtype
|
||||
return tensor_operator_registry.get('ones')(size, _dtype)
|
||||
|
||||
def sin(self):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.sin`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('sin')(self)
|
||||
|
||||
def sinc(self):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.sinc`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('sinc')(self)
|
||||
|
||||
def sinh(self):
|
||||
r"""
|
||||
Computes hyperbolic sine of the input element-wise.
|
||||
|
|
|
@ -244,6 +244,7 @@ from .math_func import (
|
|||
logsumexp,
|
||||
outer,
|
||||
sin,
|
||||
sinc,
|
||||
cos,
|
||||
tan,
|
||||
asin,
|
||||
|
|
|
@ -1669,7 +1669,8 @@ def sinc(x):
|
|||
|
||||
.. math::
|
||||
|
||||
out_i = sinc(x_i)
|
||||
out_i = \begin{cases} \frac{sin(\pi x_i)}{x_i} & x_i\neq 0\\
|
||||
1 & x_i=0 \end{cases}
|
||||
|
||||
Args:
|
||||
x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
@ -1683,7 +1684,7 @@ def sinc(x):
|
|||
TypeError: If `x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
|
||||
|
@ -9150,6 +9151,7 @@ __all__ = [
|
|||
'acos',
|
||||
'arccos',
|
||||
'atan',
|
||||
'sinc',
|
||||
'sinh',
|
||||
'cosh',
|
||||
'tanh',
|
||||
|
|
|
@ -147,6 +147,8 @@ tensor_operator_registry.register('cos', cos)
|
|||
tensor_operator_registry.register('acosh', acosh)
|
||||
tensor_operator_registry.register('cosh', P.Cosh)
|
||||
tensor_operator_registry.register('asin', asin)
|
||||
tensor_operator_registry.register('sin', sin)
|
||||
tensor_operator_registry.register('sinc', sinc)
|
||||
tensor_operator_registry.register('pow', P.Pow)
|
||||
tensor_operator_registry.register('negative', neg)
|
||||
tensor_operator_registry.register('amin', amin)
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from math import pi
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.sin()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_sin(mode):
|
||||
"""
|
||||
Feature: tensor.sin
|
||||
Description: Verify the result of sin
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor(np.array([-pi/6, pi/6, pi*10]), ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = np.array([-0.5, 0.5, 0], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output, rtol=5e-3, atol=1e-4)
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x.sinc()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_sinc(mode):
|
||||
"""
|
||||
Feature: tensor.sinc
|
||||
Description: Verify the result of sinc
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = Tensor(np.array([0.62, 0.28, 0.43, 0]), ms.float32)
|
||||
net = Net()
|
||||
output = net(x)
|
||||
expect_output = np.array([0.47735006, 0.8759357, 0.7224278, 1.], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output, rtol=5e-3, atol=1e-4)
|
Loading…
Reference in New Issue