!26505 [assistant] [ops] [I48O8K] add new array operator RightShift

Merge pull request !26505 from 王超/rightshift
This commit is contained in:
i-robot 2022-03-29 13:14:24 +00:00 committed by Gitee
commit abf0c022a9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 416 additions and 1 deletions

View File

@ -0,0 +1,157 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/right_shift_cpu_kernel.h"
#include <vector>
#include <cmath>
#include <type_traits>
#include <memory>
#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/arithmetic_cpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
const size_t kRightShiftInputsNum = 2;
const size_t kRightShiftOutputsNum = 1;
} // namespace
void RightShiftCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kRightShiftInputsNum, common::AnfAlgo::GetCNodeName(kernel_node));
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kRightShiftOutputsNum, common::AnfAlgo::GetCNodeName(kernel_node));
input_type_1_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
input_type_2_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
if (input_type_1_ != input_type_2_) {
MS_LOG(EXCEPTION) << "input1 and input2 must have the same type.";
}
input_shape_1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape_2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
}
bool RightShiftCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /* workspace */,
const std::vector<AddressPtr> &outputs) {
if (input_type_1_ == kNumberTypeInt8) {
return IntCompute<int8_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeInt16) {
return IntCompute<int16_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeInt32) {
return IntCompute<int32_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeInt64) {
return IntCompute<int64_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeUInt8) {
return UIntCompute<uint8_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeUInt16) {
return UIntCompute<uint16_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeUInt32) {
return UIntCompute<uint32_t>(inputs, outputs);
} else if (input_type_1_ == kNumberTypeUInt64) {
return UIntCompute<uint64_t>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'x' should be int8, int16, int32, int64, uint8, uint16, uint32, uint64, "
"but got "
<< TypeIdLabel(input_type_1_);
}
return true;
}
template <typename T>
bool RightShiftCpuKernelMod::IntCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
if (output_shape_.size() == 0) {
(void)output_shape_.insert(output_shape_.begin(), 1);
}
size_t output_size_ = 1;
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_size_ *= output_shape_[i];
}
BroadcastIterator base_iter(input_shape_1_, input_shape_2_, output_shape_);
auto task = [&input1, &input2, &output, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
T y_val = (input2[iter.GetInputPosB()]);
T bit_val = static_cast<T>(sizeof(T) * 8 - 1);
T zero = static_cast<T>(0);
if (y_val <= zero) {
y_val = zero;
} else if (y_val > bit_val) {
y_val = bit_val;
}
output[i] = static_cast<T>(input1[iter.GetInputPosA()] >> y_val);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return true;
}
template <typename T>
bool RightShiftCpuKernelMod::UIntCompute(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
if (output_shape_.size() == 0) {
(void)output_shape_.insert(output_shape_.begin(), 1);
}
size_t output_size_ = 1;
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_size_ *= output_shape_[i];
}
BroadcastIterator base_iter(input_shape_1_, input_shape_2_, output_shape_);
auto task = [&input1, &input2, &output, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; i++) {
T y_val = (input2[iter.GetInputPosB()]);
T bit_val = static_cast<T>(sizeof(T) * 8 - 1);
if (y_val > bit_val) {
y_val = bit_val;
}
output[i] = static_cast<T>(input1[iter.GetInputPosA()] >> y_val);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return true;
}
std::vector<KernelAttr> RightShiftCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, RightShift, RightShiftCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RIGHT_SHIFT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RIGHT_SHIFT_CPU_KERNEL_H_
#include <functional>
#include <memory>
#include <vector>
#include <iostream>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class RightShiftCpuKernelMod : public NativeCpuKernelMod {
public:
RightShiftCpuKernelMod() = default;
~RightShiftCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
TypeId input_type_1_{kTypeUnknown};
TypeId input_type_2_{kTypeUnknown};
std::vector<size_t> input_shape_1_;
std::vector<size_t> input_shape_2_;
std::vector<size_t> output_shape_;
template <typename T>
bool IntCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
bool UIntCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RIGHT_SHIFT_CPU_KERNEL_H_

View File

@ -121,6 +121,7 @@ constexpr auto kOnes = "Ones";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kIdentity = "Identity";
constexpr auto kConcat = "Concat";
constexpr auto kRightShift = "RightShift";
constexpr auto kDiag = "Diag";
constexpr auto kDiagPart = "DiagPart";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
@ -373,6 +374,7 @@ GVAR_DEF(PrimitivePtr, kPrimLstsq, std::make_shared<Primitive>(kLstsq));
GVAR_DEF(PrimitivePtr, kPrimLowerBound, std::make_shared<Primitive>(kLowerBound));
GVAR_DEF(PrimitivePtr, kPrimUpperBound, std::make_shared<Primitive>(kUpperBound));
GVAR_DEF(PrimitivePtr, kPrimCummax, std::make_shared<Primitive>(kCummax));
GVAR_DEF(PrimitivePtr, kPrimRightShift, std::make_shared<Primitive>(kRightShift));
GVAR_DEF(PrimitivePtr, kPrimTril, std::make_shared<Primitive>(kTril));
// image

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/rightshift.h"
#include <algorithm>
#include <functional>
#include <string>
#include <vector>
#include <memory>
#include <set>
#include "abstract/abstract_value.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "ops/primitive_c.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr RightShiftInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t max_dim = 8;
auto in_shape_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto in_shape_y = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("The dimension of RightShift input", SizeToLong(in_shape_x.size()),
kLessThan, max_dim, prim_name);
(void)CheckAndConvertUtils::CheckInteger("The dimension of RightShift input", SizeToLong(in_shape_y.size()),
kLessThan, max_dim, prim_name);
return BroadCastInferShape(prim_name, input_args);
}
TypePtr RightShiftInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto y = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
(void)abstract::CheckDtypeSame(prim_name, x, y);
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (!input_type->isa<TensorType>()) {
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s"
<< " input must be tensor type but got " << input_type->ToString();
}
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, prim_name);
return input_type;
}
} // namespace
MIND_API_BASE_IMPL(RightShift, PrimitiveC, BaseOperator);
AbstractBasePtr RightShiftInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = RightShiftInferType(primitive, input_args);
auto infer_shape = RightShiftInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(RightShift, prim::kPrimRightShift, RightShiftInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RIGHTSHIFT_H_
#define MINDSPORE_CORE_OPS_RIGHTSHIFT_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameRightShift = "RightShift";
/// \brief Shift x to the right by y in element-wise.
/// Refer to Python API @ref mindspore.ops.RightShift for more details.
class MIND_API RightShift : public BaseOperator {
public:
MIND_API_BASE_MEMBER(RightShift);
/// \brief Constructor.
RightShift() : BaseOperator(kNameRightShift) { InitIOName({"input_x", "input_y"}, {"output"}); }
/// \brief Init.
void Init() {}
};
abstract::AbstractBasePtr RightShiftInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimRightShift = std::shared_ptr<RightShift>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RIGHTSHIFT_H_

View File

@ -133,4 +133,5 @@ from .priority_replay_buffer import _prb_create_op_cpu
from .priority_replay_buffer import _prb_push_op_cpu
from .priority_replay_buffer import _prb_sample_op_cpu
from .priority_replay_buffer import _prb_update_op_cpu
from .right_shift import _right_shift_aicpu
from .tril import _tril_aicpu

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RightShift op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
right_shift_op_info = AiCPURegOp("RightShift") \
.fusion_type("OPAQUE") \
.input(0, "input_x", "required") \
.input(1, "input_y", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.get_op_info()
@op_info_register(right_shift_op_info)
def _right_shift_aicpu():
"""rightshift aicpu register"""
return

View File

@ -7274,6 +7274,43 @@ class Cummax(Primitive):
self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices'])
class RightShift(Primitive):
r"""
Shift the value of each position of the tensor to the right several bits.
The inputs are two tensors, dtypes of them must be consistent, and the
shapes of them could be broadcast.
.. math::
\begin{aligned}
&out_{i} =x_{i} >> y_{i}
\end{aligned}
Inputs:
- **input_x** (Tensor) - The target tensor, will be shifted to the right
by y in element-wise.
- **input_y** (Tensor) - The tensor must have the same type as input_x.
Outputs:
- **output** (Tensor) - The output tensor, has the same type as input_x.
Raises:
TypeError: If `input_x` or `input_y` is not tensor.
TypeError: If `input_x` and `input_y` could not be broadcast.
>>> rightshift = ops.RightShift()
>>> input_x = Tensor(np.array([1, 2, 3]).astype(np.uint8))
>>> input_y = Tensor(np.array([1, 1, 1]).astype(np.uint8))
>>> output = rightshift(input_x, input_y)
>>> print(output)
[0 1 1]
"""
@prim_attr_register
def __init__(self):
"""Initialize RightShift."""
self.init_prim_io_names(inputs=['input_x', 'input_y'], outputs=['output'])
class Tril(Primitive):
"""
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -34,6 +34,7 @@ from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.array_ops import Tril
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
@ -2550,6 +2551,11 @@ test_case_array_ops = [
'desc_const': [(64,)],
'desc_inputs': [[64, 1]],
'desc_bprop': [[64]]}),
('RightShift', {
'block': RightShift(),
'desc_inputs': [Tensor(np.array([1, 2, 3]), mstype.int32),
Tensor(np.array([5, 2, 3]), mstype.int32)],
'skip': ['backward']}),
('Cast', {
'block': P.Cast(),
'desc_const': [mstype.int32],