!46318 Add bitwise left & right shift
Merge pull request !46318 from shaojunsong/feature/api1130
This commit is contained in:
commit
dcf3689c1a
|
@ -182,7 +182,9 @@ mindspore.ops
|
|||
mindspore.ops.bessel_y0
|
||||
mindspore.ops.bessel_y1
|
||||
mindspore.ops.bitwise_and
|
||||
mindspore.ops.bitwise_left_shift
|
||||
mindspore.ops.bitwise_or
|
||||
mindspore.ops.bitwise_right_shift
|
||||
mindspore.ops.bitwise_xor
|
||||
mindspore.ops.ceil
|
||||
mindspore.ops.clamp
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.bitwise_left_shift
|
||||
====================================
|
||||
|
||||
.. py:method:: mindspore.Tensor.bitwise_left_shift(other)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.bitwise_left_shift`。
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.bitwise_right_shift
|
||||
====================================
|
||||
|
||||
.. py:method:: mindspore.Tensor.bitwise_right_shift(other)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.bitwise_right_shift`。
|
|
@ -59,7 +59,9 @@ mindspore.Tensor
|
|||
mindspore.Tensor.baddbmm
|
||||
mindspore.Tensor.bernoulli
|
||||
mindspore.Tensor.bitwise_and
|
||||
mindspore.Tensor.bitwise_left_shift
|
||||
mindspore.Tensor.bitwise_or
|
||||
mindspore.Tensor.bitwise_right_shift
|
||||
mindspore.Tensor.bitwise_xor
|
||||
mindspore.Tensor.bmm
|
||||
mindspore.Tensor.bool
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
mindspore.ops.bitwise_left_shift
|
||||
=================================
|
||||
|
||||
.. py:function:: mindspore.ops.bitwise_left_shift(x, other)
|
||||
|
||||
对输入 `x` 进行左移 `other` 位运算。
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{aligned}
|
||||
&out_{i} =x_{i} << other_{i}
|
||||
\end{aligned}
|
||||
|
||||
参数:
|
||||
- **x** (Union[Tensor, Scalar]) - 被左移的输入。
|
||||
- **other** (Union[Tensor, Scalar]) - 左移的位数。
|
||||
|
||||
返回:
|
||||
Tensor,左移位运算后的结果。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 或 `other` 都不是Tensor。
|
||||
- **TypeError** - `x` 或 `other` 不是int、int类型的Tensor或uint类型的Tensor。
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
mindspore.ops.bitwise_right_shift
|
||||
=================================
|
||||
|
||||
.. py:function:: mindspore.ops.bitwise_right_shift(x, other)
|
||||
|
||||
对输入 `x` 进行右移 `other` 位运算。
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{aligned}
|
||||
&out_{i} =x_{i} >> other_{i}
|
||||
\end{aligned}
|
||||
|
||||
参数:
|
||||
- **x** (Union[Tensor, Scalar]) - 被右移的输入。
|
||||
- **other** (Union[Tensor, Scalar]) - 右移的位数。
|
||||
|
||||
返回:
|
||||
Tensor,右移位运算后的结果。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 或 `other` 都不是Tensor。
|
||||
- **TypeError** - `x` 或 `other` 不是int、int类型的Tensor或uint类型的Tensor。
|
||||
|
|
@ -65,7 +65,9 @@
|
|||
mindspore.Tensor.baddbmm
|
||||
mindspore.Tensor.bernoulli
|
||||
mindspore.Tensor.bitwise_and
|
||||
mindspore.Tensor.bitwise_left_shift
|
||||
mindspore.Tensor.bitwise_or
|
||||
mindspore.Tensor.bitwise_right_shift
|
||||
mindspore.Tensor.bitwise_xor
|
||||
mindspore.Tensor.bmm
|
||||
mindspore.Tensor.bool
|
||||
|
|
|
@ -184,6 +184,8 @@ Element-by-Element Operations
|
|||
mindspore.ops.bitwise_and
|
||||
mindspore.ops.bitwise_or
|
||||
mindspore.ops.bitwise_xor
|
||||
mindspore.ops.bitwise_left_shift
|
||||
mindspore.ops.bitwise_right_shift
|
||||
mindspore.ops.ceil
|
||||
mindspore.ops.clip
|
||||
mindspore.ops.clamp
|
||||
|
|
|
@ -216,6 +216,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"bitwise_and", std::string("bitwise_and")}, // P.BitwiseAnd()
|
||||
{"bitwise_or", std::string("bitwise_or")}, // P.BitwiseOr()
|
||||
{"bitwise_xor", std::string("bitwise_xor")}, // P.BitwiseXor()
|
||||
{"bitwise_left_shift", std::string("bitwise_left_shift")}, // bitwise_left_shift
|
||||
{"bitwise_right_shift", std::string("bitwise_right_shift")}, // bitwise_right_shift
|
||||
{"tan", std::string("tan")}, // P.Tan()
|
||||
{"ger", std::string("ger")}, // P.Ger()
|
||||
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
|
||||
|
|
|
@ -2719,6 +2719,16 @@ def bitwise_xor(x, y):
|
|||
return F.bitwise_xor(x, y)
|
||||
|
||||
|
||||
def bitwise_left_shift(x, y):
|
||||
"""Returns bitwise left shift of `x` by `other` bits."""
|
||||
return F.bitwise_left_shift(x, y)
|
||||
|
||||
|
||||
def bitwise_right_shift(x, y):
|
||||
"""Returns bitwise right shift of `x` by `other` bits."""
|
||||
return F.bitwise_right_shift(x, y)
|
||||
|
||||
|
||||
def exp(x):
|
||||
"""Returns exponential of a tensor element-wise."""
|
||||
return F.exp(x)
|
||||
|
|
|
@ -1140,6 +1140,22 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('bitwise_xor')(self, x)
|
||||
|
||||
def bitwise_left_shift(self, other):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.bitwise_left_shift`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('bitwise_left_shift')(self, other)
|
||||
|
||||
def bitwise_right_shift(self, other):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.bitwise_left_shift`.
|
||||
"""
|
||||
self._init_check()
|
||||
_cast = tensor_operator_registry.get('cast')
|
||||
other = _cast(other, self.dtype)
|
||||
return tensor_operator_registry.get('bitwise_right_shift')(self, other)
|
||||
|
||||
def scatter_mul(self, indices, updates):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.scatter_mul`.
|
||||
|
|
|
@ -261,6 +261,8 @@ from .math_func import (
|
|||
bitwise_and,
|
||||
bitwise_or,
|
||||
bitwise_xor,
|
||||
bitwise_left_shift,
|
||||
bitwise_right_shift,
|
||||
erf,
|
||||
erfc,
|
||||
cdist,
|
||||
|
|
|
@ -2258,6 +2258,115 @@ def bitwise_xor(x, y):
|
|||
return bitwise_xor_(x, y)
|
||||
|
||||
|
||||
def bitwise_left_shift(x, other):
|
||||
r"""
|
||||
Calculates the left arithmetic shift of `x` by `other` bits.
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{aligned}
|
||||
&out_{i} =x_{i} << other_{i}
|
||||
\end{aligned}
|
||||
|
||||
Args:
|
||||
x (Tensor or Scalar): The input to be left shifted.
|
||||
other(Tensor or Scalar): The number of bit to be applied on left arithmetic shift.
|
||||
|
||||
Returns:
|
||||
Tensor, the result after bitwise left shift.
|
||||
|
||||
Raises:
|
||||
TypeError: If neither `x` nor `other` is a tensor.
|
||||
TypeError: If either `x` or `other` is not an int or a tensor of dtype: int or uint.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1024, 2]), mindspore.int16)
|
||||
>>> y = Tensor(np.array([2]), mindspore.int16)
|
||||
>>> output = ops.bitwise_left_shift(x, y)
|
||||
>>> print(output)
|
||||
[4096 8]
|
||||
"""
|
||||
if isinstance(x, numbers.Number) and isinstance(other, numbers.Number):
|
||||
raise TypeError(f"For 'bitwise_left_shift', at least one of the inputs should be a Tensor.")
|
||||
|
||||
cast = ops.Cast()
|
||||
white_list = [mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64]
|
||||
if isinstance(x, numbers.Number):
|
||||
_dtype = other.dtype
|
||||
if not isinstance(x, int):
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'x' must be an integer, but got x:{type(x)}.")
|
||||
if _dtype not in white_list:
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'other' must be a Tensor of int or uint, but got {_dtype}.")
|
||||
x = cast(x, other.dtype)
|
||||
elif isinstance(other, numbers.Number):
|
||||
_dtype = x.dtype
|
||||
if not isinstance(other, int):
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'other' must be an integer, but got other:{type(other)}.")
|
||||
if _dtype not in white_list:
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'x' must be a Tensor of int or uint, but got {_dtype}.")
|
||||
other = cast(other, x.dtype)
|
||||
ls = ops.LeftShift()
|
||||
return ls(x, other)
|
||||
|
||||
|
||||
def bitwise_right_shift(x, other):
|
||||
r"""
|
||||
Calculates the right arithmetic shift of `x` by `other` bits.
|
||||
|
||||
.. math::
|
||||
|
||||
\begin{aligned}
|
||||
&out_{i} =x_{i} >> y_{i}
|
||||
\end{aligned}
|
||||
|
||||
Args:
|
||||
x (Tensor or Scalar): The input to be right shifted.
|
||||
other(Tensor or Scalar): The number of bit to be applied on right arithmetic shift.
|
||||
|
||||
Returns:
|
||||
Tensor, the result after bitwise right shift.
|
||||
|
||||
Raises:
|
||||
TypeError: If neither `x` nor `other` is a tensor.
|
||||
TypeError: If either `x` or `other` is not an int or a tensor of dtype: int or uint.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1024, 2]), mindspore.int16)
|
||||
>>> y = Tensor(np.array([2]), mindspore.int16)
|
||||
>>> output = ops.bitwise_right_shift(x, y)
|
||||
>>> print(output)
|
||||
[256 0]
|
||||
"""
|
||||
if isinstance(x, numbers.Number) and isinstance(other, numbers.Number):
|
||||
raise TypeError(f"For 'bitwise_left_shift', at least one of the inputs should be a Tensor.")
|
||||
cast = ops.Cast()
|
||||
white_list = [mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64]
|
||||
if isinstance(x, numbers.Number):
|
||||
_dtype = other.dtype
|
||||
if not isinstance(x, int):
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'x' must be an integer, but got x:{type(x)}.")
|
||||
if _dtype not in white_list:
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'other' must be a Tensor of int or uint, but got {_dtype}.")
|
||||
x = cast(x, other.dtype)
|
||||
elif isinstance(other, numbers.Number):
|
||||
_dtype = x.dtype
|
||||
if not isinstance(other, int):
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'other' must be an integer, but got other:{type(other)}.")
|
||||
if _dtype not in white_list:
|
||||
raise TypeError(f"For 'bitwise_left_shift', 'x' must be a Tensor of int or uint, but got {_dtype}.")
|
||||
other = cast(other, x.dtype)
|
||||
rs = ops.RightShift()
|
||||
return rs(x, other)
|
||||
|
||||
|
||||
def inv(x):
|
||||
r"""
|
||||
Computes Reciprocal of input tensor element-wise.
|
||||
|
@ -8102,6 +8211,8 @@ __all__ = [
|
|||
'bitwise_and',
|
||||
'bitwise_or',
|
||||
'bitwise_xor',
|
||||
'bitwise_left_shift',
|
||||
'bitwise_right_shift',
|
||||
'inv',
|
||||
'inverse',
|
||||
'invert',
|
||||
|
|
|
@ -184,6 +184,8 @@ tensor_operator_registry.register('is_floating_point', is_floating_point)
|
|||
tensor_operator_registry.register('bitwise_and', bitwise_and)
|
||||
tensor_operator_registry.register('bitwise_or', bitwise_or)
|
||||
tensor_operator_registry.register('bitwise_xor', bitwise_xor)
|
||||
tensor_operator_registry.register('bitwise_left_shift', bitwise_left_shift)
|
||||
tensor_operator_registry.register('bitwise_right_shift', bitwise_right_shift)
|
||||
tensor_operator_registry.register('ger', ger)
|
||||
tensor_operator_registry.register('reduce_max', P.ReduceMax)
|
||||
tensor_operator_registry.register('reduce_min', P.ReduceMin)
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Left(nn.Cell):
|
||||
def construct(self, x, other):
|
||||
return ops.bitwise_left_shift(x, other)
|
||||
|
||||
|
||||
class Right(nn.Cell):
|
||||
def construct(self, x, other):
|
||||
return ops.bitwise_right_shift(x, other)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_left_shift_right_shift_normal(mode):
|
||||
"""
|
||||
Feature: tensor.bitwise_left_shift & tensor.bitwise_right_shift
|
||||
Description: Verify the result of the above tensor apis
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = ms.Tensor(np.array([[1024, 5, 6]]), ms.int32)
|
||||
left = Left()
|
||||
right = Right()
|
||||
|
||||
other_left = ms.Tensor(np.array([2]), ms.int8)
|
||||
other_right = ms.Tensor(np.array([1]), ms.int8)
|
||||
|
||||
left_output = left(x, other_left)
|
||||
right_output = right(x, other_right)
|
||||
|
||||
expected_left = np.array([4096, 20, 24], np.int32)
|
||||
expected_right = np.array([512, 2, 3], np.int32)
|
||||
|
||||
assert left_output.dtype == x.dtype
|
||||
assert right_output.dtype == x.dtype
|
||||
assert np.allclose(left_output.asnumpy(), expected_left)
|
||||
assert np.allclose(right_output.asnumpy(), expected_right)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Left(nn.Cell):
|
||||
def construct(self, x, other):
|
||||
return x.bitwise_left_shift(other)
|
||||
|
||||
|
||||
class Right(nn.Cell):
|
||||
def construct(self, x, other):
|
||||
return x.bitwise_right_shift(other)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_left_shift_right_shift_normal(mode):
|
||||
"""
|
||||
Feature: tensor.bitwise_left_shift & tensor.bitwise_right_shift
|
||||
Description: Verify the result of the above tensor apis
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
x = ms.Tensor(np.array([[1024, 5, 6]]), ms.int32)
|
||||
left = Left()
|
||||
right = Right()
|
||||
|
||||
other_left = ms.Tensor(np.array([2]), ms.int8)
|
||||
other_right = ms.Tensor(np.array([1]), ms.int8)
|
||||
|
||||
left_output = left(x, other_left)
|
||||
right_output = right(x, other_right)
|
||||
|
||||
expected_left = np.array([4096, 20, 24], np.int32)
|
||||
expected_right = np.array([512, 2, 3], np.int32)
|
||||
|
||||
assert left_output.dtype == x.dtype
|
||||
assert right_output.dtype == x.dtype
|
||||
assert np.allclose(left_output.asnumpy(), expected_left)
|
||||
assert np.allclose(right_output.asnumpy(), expected_right)
|
Loading…
Reference in New Issue