!49439 sequence addn cpu kernel

Merge pull request !49439 from chenweifeng/sequence-addn
i-robot 2023-02-27 12:44:52 +00:00 committed by Gitee
commit 21e787b00e
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 417 additions and 0 deletions

View File

@@ -699,6 +699,7 @@ constexpr auto kSegmentProdOpName = "SegmentProd";
constexpr auto kSegmentSumOpName = "SegmentSum";
constexpr auto kSequenceAddOpName = "SequenceAdd";
constexpr auto kSequenceAddOffsetOpName = "SequenceAddOffset";
constexpr auto kSequenceAddNOpName = "SequenceAddN";
constexpr auto kSelectOpName = "Select";
constexpr auto kSelfAdjointEigOpName = "SelfAdjointEig";
constexpr auto kSeLUOpName = "SeLU";

View File

@@ -206,6 +206,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"ScalarSub", kAnyType},
{"SelfAdjointEig", kTupleTensor2},
{"SequenceAdd", kAnyType},
{"SequenceAddN", kAnyType},
{"SequenceCount", kAnyType},
{"SequenceIndex", kAnyType},
{"SequenceMul", kAnyType},

View File

@@ -0,0 +1,129 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sequence/sequence_addn_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <complex>
#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h"
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
#include "utils/ms_utils.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr int kInputsNum = 1;
constexpr int kOutputsNum = 1;
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
template <typename T>
void Add(const T *in0, const T *in1, T *out, int start, int end) {
for (int index = start; index < end; index++) {
out[index] = in0[index] + in1[index];
}
}
template <>
void Add(const int *in_0, const int *in_1, int *out, int start, int end) {
int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "For 'AddN', AddInt failed.";
}
}
template <>
void Add(const float *in_0, const float *in_1, float *out, int start, int end) {
int ret = ElementAdd(in_0 + start, in_1 + start, out + start, end - start);
if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "For 'AddN', AddFloat failed.";
}
}
} // namespace
bool SequenceAddNCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
return MatchKernelFunc(base_operator, inputs, outputs);
}
int SequenceAddNCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
if (ret != 0) {
return ret;
}
tuple_shape_ = inputs[0]->GetShapeVector();
if (tuple_shape_.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the input tuple size must greater 0";
}
return KRET_OK;
}
template <typename T>
bool SequenceAddNCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
// The tuple arrives as one contiguous buffer; each element spans elements_num values of type T.
size_t elements_num = outputs[0]->size / sizeof(T);
const auto input_0 = reinterpret_cast<T *>(inputs[0]->addr);
auto input_1 = input_0 + elements_num;
auto output = reinterpret_cast<T *>(outputs[0]->addr);
// Seed the output with element 0 + element 1, then fold each remaining element into the output.
auto task_0 = std::bind(Add<T>, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task_0, elements_num, this, &parallel_search_info_);
for (int64_t index = 2; index < tuple_shape_[0]; ++index) {
input_1 += elements_num;
auto task = std::bind(Add<T>, input_1, output, output, std::placeholders::_1, std::placeholders::_2);
ParallelLaunchAutoSearch(task, elements_num, this, &parallel_search_info_);
}
return true;
}
#define SEQUENCE_ADDN_REG(ms_type, builtin_type) \
{ \
KernelAttr().AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(kObjectTypeNumber, ms_type), \
&SequenceAddNCpuKernelMod::LaunchKernel<builtin_type> \
}, \
{ \
KernelAttr().AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(ms_type), \
&SequenceAddNCpuKernelMod::LaunchKernel<builtin_type> \
}
const SequenceAddNCpuKernelMod::FuncList &SequenceAddNCpuKernelMod::GetFuncList() const {
static const FuncList func_list = {
SEQUENCE_ADDN_REG(kNumberTypeInt8, int8_t), SEQUENCE_ADDN_REG(kNumberTypeInt16, int16_t),
SEQUENCE_ADDN_REG(kNumberTypeInt32, int32_t), SEQUENCE_ADDN_REG(kNumberTypeInt64, int64_t),
SEQUENCE_ADDN_REG(kNumberTypeUInt8, uint8_t), SEQUENCE_ADDN_REG(kNumberTypeUInt16, uint16_t),
SEQUENCE_ADDN_REG(kNumberTypeUInt32, uint32_t), SEQUENCE_ADDN_REG(kNumberTypeUInt64, uint64_t),
SEQUENCE_ADDN_REG(kNumberTypeFloat16, float16), SEQUENCE_ADDN_REG(kNumberTypeFloat32, float),
SEQUENCE_ADDN_REG(kNumberTypeFloat64, double), SEQUENCE_ADDN_REG(kNumberTypeComplex64, complex64),
SEQUENCE_ADDN_REG(kNumberTypeComplex128, complex128)};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SequenceAddN, SequenceAddNCpuKernelMod);
} // namespace kernel
} // namespace mindspore
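
For intuition, LaunchKernel above receives the whole tuple as one contiguous buffer and folds the elements into the output pairwise. A minimal NumPy sketch of that accumulation pattern (the helper name is hypothetical; like the kernel, it assumes the tuple holds at least two elements):

import numpy as np

def sequence_addn(flat_buffer, n, elements_num):
    """Sum n tuple elements stored back-to-back in one flat buffer."""
    # Seed with element 0 + element 1, mirroring the kernel's first Add call.
    out = flat_buffer[:elements_num] + flat_buffer[elements_num:2 * elements_num]
    # Fold each remaining element into the running sum.
    for i in range(2, n):
        out = out + flat_buffer[i * elements_num:(i + 1) * elements_num]
    return out

# Three 2x2 tensors flattened into one buffer, matching the second test below.
buf = np.array([1, 2, 2, 3, 2, 3, 3, 4, 3, 4, 4, 5], dtype=np.float32)
print(sequence_addn(buf, 3, 4))  # [ 6.  9.  9. 12.]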

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_ADDN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_ADDN_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SequenceAddNCpuKernelMod : public NativeCpuKernelMod,
public MatchKernelHelper<SequenceAddNCpuKernelMod, AddressPtr> {
public:
SequenceAddNCpuKernelMod() = default;
~SequenceAddNCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs);
}
using FuncList = std::vector<std::pair<KernelAttr, KernelRunFunc>>;
const FuncList &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
std::vector<int64_t> tuple_shape_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_ADDN_CPU_KERNEL_H_

View File

@@ -291,6 +291,7 @@ constexpr auto kSequenceSliceGrad = "SequenceSliceGrad";
constexpr auto kSequenceSliceSetItem = "SequenceSliceSetItem";
constexpr auto kSequenceMax = "SequenceMax";
constexpr auto kSequenceMin = "SequenceMin";
constexpr auto kSequenceAddN = "SequenceAddN";
// NN
constexpr auto kFractionalMaxPoolWithFixedKsize = "FractionalMaxPoolWithFixedKsize";
@@ -1638,6 +1639,7 @@ GVAR_DEF(PrimitivePtr, kPrimSequenceAddOffset, std::make_shared<Primitive>(kSequ
GVAR_DEF(PrimitivePtr, kPrimSequenceSliceGrad, std::make_shared<Primitive>(kSequenceSliceGrad));
GVAR_DEF(PrimitivePtr, kPrimSequenceMax, std::make_shared<Primitive>(kSequenceMax));
GVAR_DEF(PrimitivePtr, kPrimSequenceMin, std::make_shared<Primitive>(kSequenceMin));
GVAR_DEF(PrimitivePtr, kPrimSequenceAddN, std::make_shared<Primitive>(kSequenceAddN));
// Other miscellaneous
GVAR_DEF(PrimitivePtr, kPrimSampleDistortedBoundingBoxV2, std::make_shared<Primitive>(kSampleDistortedBoundingBoxV2));

View File

@@ -0,0 +1,79 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sequence_addn.h"
#include <vector>
#include <string>
#include <memory>
#include "utils/check_convert_utils.h"
#include "abstract/abstract_value.h"
#include "abstract/ops/op_infer.h"
#include "abstract/ops/primitive_infer_map.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ops/core_ops.h"
#include "ops/primitive_c.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
AbstractBasePtr SequenceAddNInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
// Inputs: a tuple or list.
constexpr int args_spec_size = 1;
abstract::CheckArgsSize(op_name, input_args, args_spec_size);
auto queue = abstract::CheckArg<abstract::AbstractSequence>(op_name, input_args, 0);
// The value of dynamic_len_element_abs is kAnyValue, so there is no need to Broaden it.
if (queue->dynamic_len()) {
auto element_abs = queue->dynamic_len_element_abs();
MS_EXCEPTION_IF_NULL(element_abs);
return element_abs->Clone();
}
if (queue->elements().size() == 0) {
MS_LOG(EXCEPTION) << "Sequence length should not be 0.";
}
return queue->elements()[0];
}
} // namespace
MIND_API_OPERATOR_IMPL(SequenceAddN, BaseOperator);
class SequenceAddNInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceAddNInferInner(primitive, input_args)->BuildShape();
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceAddNInferInner(prim, input_args)->BuildType();
}
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
return SequenceAddNInferInner(primitive, input_args);
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceAddN, prim::kPrimSequenceAddN, SequenceAddNInfer, false);
} // namespace ops
} // namespace mindspore
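
The inference rule above is deliberately simple: the output abstract (shape and dtype) is taken from a sequence element. An illustrative Python sketch of the same rule; the class and attribute names are hypothetical stand-ins, not the MindSpore abstract API:

from dataclasses import dataclass, field

@dataclass
class SeqAbstract:
    dynamic_len: bool = False
    element_abs: object = None  # shared element abstract for dynamic-length tuples
    elements: list = field(default_factory=list)

def infer_sequence_addn(seq):
    if seq.dynamic_len:
        # Dynamic-length tuple: every element shares one abstract whose value is
        # already kAnyValue, so a clone of it describes the output.
        return seq.element_abs
    if not seq.elements:
        raise ValueError("Sequence length should not be 0.")
    # Fixed-length tuple: the sum takes the shape/dtype of the first element.
    return seq.elements[0]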

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SEQUENCE_ADDN_H_
#define MINDSPORE_CORE_OPS_SEQUENCE_ADDN_H_
#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace ops {
/// \brief Sequence addition operation
class MIND_API SequenceAddN : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SequenceAddN);
/// \brief Constructor.
SequenceAddN() : BaseOperator(prim::kSequenceAddN) {}
/// \brief Init function.
void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SEQUENCE_ADDN_H_

View File

@@ -597,6 +597,31 @@ class SequenceMin(Primitive):
self.init_prim_io_names(inputs=['sequence'], outputs=['output_data'])
class SequenceAddN(Primitive):
r"""
Support sequence AddN operation.
.. note::
This is only for internal use.
Inputs:
- **sequence** (Union[List, Tuple]) - The sequence.
Outputs:
The sum of all elements in the input sequence.
Raises:
TypeError: If the 'sequence' is not a list or tuple.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register
def __init__(self):
"""Initialize SequenceAddN"""
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
class tuple_greater_than(Primitive):
r"""
Support tuple_greater_than operation 'greater_than(target)'.
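
A minimal usage sketch for the SequenceAddN primitive defined above, mirroring the CPU tests added below (device_target is an assumption; any platform listed in the docstring should work):

import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import mutable
from mindspore.ops.operations._sequence_ops import SequenceAddN

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.addn = SequenceAddN()

    def construct(self, seq):
        return self.addn(seq)

# Wrap the tuple in mutable(..., True) so it is treated as a dynamic-length sequence.
seq = mutable((Tensor(1), Tensor(2), Tensor(3), Tensor(4)), True)
print(Net()(seq))  # 10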

View File

@@ -0,0 +1,82 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import mutable
from mindspore.ops.operations._sequence_ops import SequenceAddN
from sequence_help import context_prepare
context.set_context(mode=context.GRAPH_MODE)
context_prepare()
class NetSequenceAddN(nn.Cell):
def __init__(self):
super().__init__()
self.op = SequenceAddN()
def construct(self, seq):
return self.op(seq)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_tensor_addn():
"""
Feature: test sequence addn op
Description: AddN operation on a tuple of scalar Tensors
Expectation: the result matches the expected sum
"""
seq = mutable((Tensor(1), Tensor(2), Tensor(3), Tensor(4)), True)
expect = Tensor(10)
net = NetSequenceAddN()
res = net(seq)
assert np.all(res.asnumpy() == expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_tensor_addn1():
"""
Feature: test sequence addn op
Description: AddN operation on a tuple of 2-D Tensors
Expectation: the result matches the expected elementwise sum
"""
seq = mutable((Tensor([[1, 2], [2, 3]]), Tensor([[2, 3], [3, 4]]), Tensor([[3, 4], [4, 5]])), True)
expect = Tensor([[6, 9], [9, 12]])
net = NetSequenceAddN()
res = net(seq)
assert np.all(res.asnumpy() == expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_addn():
"""
Feature: test sequence addn op
Description: AddN operation on a tuple of integers
Expectation: the result matches the expected sum
"""
seq = mutable((1, 2, 3, 4, 5, 6), True)
expect = 21
net = NetSequenceAddN()
res = net(seq)
assert res == expect