forked from mindspore-Ecosystem/mindspore
!49439 sequence addn cpu kernel
Merge pull request !49439 from chenweifeng/sequence-addn
commit 21e787b00e
@@ -699,6 +699,7 @@ constexpr auto kSegmentProdOpName = "SegmentProd";
constexpr auto kSegmentSumOpName = "SegmentSum";
constexpr auto kSequenceAddOpName = "SequenceAdd";
constexpr auto kSequenceAddOffsetOpName = "SequenceAddOffset";
constexpr auto kSequenceAddNOpName = "SequenceAddN";
constexpr auto kSelectOpName = "Select";
constexpr auto kSelfAdjointEigOpName = "SelfAdjointEig";
constexpr auto kSeLUOpName = "SeLU";
@@ -206,6 +206,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"ScalarSub", kAnyType},
{"SelfAdjointEig", kTupleTensor2},
{"SequenceAdd", kAnyType},
{"SequenceAddN", kAnyType},
{"SequenceCount", kAnyType},
{"SequenceIndex", kAnyType},
{"SequenceMul", kAnyType},
@@ -0,0 +1,129 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/sequence/sequence_addn_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <complex>
#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h"
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
#include "utils/ms_utils.h"
#include "include/common/thread_pool.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr int kInputsNum = 1;
constexpr int kOutputsNum = 1;

using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

// Generic element-wise addition of two operands over the index range [start, end).
template <typename T>
void Add(const T *in0, const T *in1, T *out, int start, int end) {
  for (int index = start; index < end; index++) {
    out[index] = in0[index] + in1[index];
  }
}

// Specializations that dispatch to the vectorized nnacl element-wise add kernels.
template <>
void Add(const int *in_0, const int *in_1, int *out, int start, int end) {
  int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
  if (ret != NNACL_OK) {
    MS_LOG(EXCEPTION) << "For 'SequenceAddN', ElementAddInt failed.";
  }
}

template <>
void Add(const float *in_0, const float *in_1, float *out, int start, int end) {
  int ret = ElementAdd(in_0 + start, in_1 + start, out + start, end - start);
  if (ret != NNACL_OK) {
    MS_LOG(EXCEPTION) << "For 'SequenceAddN', ElementAdd failed.";
  }
}
}  // namespace

bool SequenceAddNCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
  return MatchKernelFunc(base_operator, inputs, outputs);
}

int SequenceAddNCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs,
                                     const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
  if (ret != 0) {
    return ret;
  }

  tuple_shape_ = inputs[0]->GetShapeVector();
  if (tuple_shape_.empty()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the input tuple size must be greater than 0.";
  }

  return KRET_OK;
}

template <typename T>
bool SequenceAddNCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &workspace,
                                            const std::vector<AddressPtr> &outputs) {
  // The tuple input is laid out contiguously: element i starts at input_0 + i * elements_num.
  size_t elements_num = outputs[0]->size / sizeof(T);
  const auto input_0 = reinterpret_cast<T *>(inputs[0]->addr);
  auto input_1 = input_0 + elements_num;

  // First pass: out = seq[0] + seq[1].
  auto output = reinterpret_cast<T *>(outputs[0]->addr);
  auto task_0 = std::bind(Add<T>, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
  ParallelLaunchAutoSearch(task_0, elements_num, this, &parallel_search_info_);

  // Remaining passes: accumulate each further element into the output buffer.
  for (int64_t index = 2; index < tuple_shape_[0]; ++index) {
    input_1 += elements_num;
    auto task = std::bind(Add<T>, input_1, output, output, std::placeholders::_1, std::placeholders::_2);
    ParallelLaunchAutoSearch(task, elements_num, this, &parallel_search_info_);
  }
  return true;
}

#define SEQUENCE_ADDN_REG(ms_type, builtin_type)                                                    \
  {                                                                                                 \
    KernelAttr().AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(kObjectTypeNumber, ms_type), \
      &SequenceAddNCpuKernelMod::LaunchKernel<builtin_type>                                         \
  },                                                                                                \
  {                                                                                                 \
    KernelAttr().AddInputAttr(kObjectTypeTuple, ms_type).AddOutputAttr(ms_type),                    \
      &SequenceAddNCpuKernelMod::LaunchKernel<builtin_type>                                         \
  }

const SequenceAddNCpuKernelMod::FuncList &SequenceAddNCpuKernelMod::GetFuncList() const {
  static const FuncList func_list = {
    SEQUENCE_ADDN_REG(kNumberTypeInt8, int8_t),      SEQUENCE_ADDN_REG(kNumberTypeInt16, int16_t),
    SEQUENCE_ADDN_REG(kNumberTypeInt32, int32_t),    SEQUENCE_ADDN_REG(kNumberTypeInt64, int64_t),
    SEQUENCE_ADDN_REG(kNumberTypeUInt8, uint8_t),    SEQUENCE_ADDN_REG(kNumberTypeUInt16, uint16_t),
    SEQUENCE_ADDN_REG(kNumberTypeUInt32, uint32_t),  SEQUENCE_ADDN_REG(kNumberTypeUInt64, uint64_t),
    SEQUENCE_ADDN_REG(kNumberTypeFloat16, float16),  SEQUENCE_ADDN_REG(kNumberTypeFloat32, float),
    SEQUENCE_ADDN_REG(kNumberTypeFloat64, double),   SEQUENCE_ADDN_REG(kNumberTypeComplex64, complex64),
    SEQUENCE_ADDN_REG(kNumberTypeComplex128, complex128)};
  return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SequenceAddN, SequenceAddNCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
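For reference, a minimal NumPy sketch of the semantics LaunchKernel implements: the first pass writes seq[0] + seq[1] into the output buffer, and every remaining element is accumulated into it. This is illustrative only, not part of the PR, and it assumes (as the kernel's first pass does) that the tuple holds at least two elements:

import numpy as np

def sequence_addn_reference(seq):
    # First pass: out = seq[0] + seq[1].
    out = seq[0] + seq[1]
    # Remaining passes: out = seq[i] + out, mirroring Add(input_1, output, output).
    for i in range(2, len(seq)):
        out = seq[i] + out
    return out

print(sequence_addn_reference([np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]))  # [ 9 12]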
@@ -0,0 +1,61 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_ADDN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_ADDN_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class SequenceAddNCpuKernelMod : public NativeCpuKernelMod,
                                 public MatchKernelHelper<SequenceAddNCpuKernelMod, AddressPtr> {
 public:
  SequenceAddNCpuKernelMod() = default;
  ~SequenceAddNCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, workspace, outputs);
  }

  using FuncList = std::vector<std::pair<KernelAttr, KernelRunFunc>>;
  const FuncList &GetFuncList() const override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  std::vector<int64_t> tuple_shape_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SEQUENCE_ADDN_CPU_KERNEL_H_
@@ -291,6 +291,7 @@ constexpr auto kSequenceSliceGrad = "SequenceSliceGrad";
constexpr auto kSequenceSliceSetItem = "SequenceSliceSetItem";
constexpr auto kSequenceMax = "SequenceMax";
constexpr auto kSequenceMin = "SequenceMin";
constexpr auto kSequenceAddN = "SequenceAddN";

// NN
constexpr auto kFractionalMaxPoolWithFixedKsize = "FractionalMaxPoolWithFixedKsize";
@@ -1638,6 +1639,7 @@ GVAR_DEF(PrimitivePtr, kPrimSequenceAddOffset, std::make_shared<Primitive>(kSequ
GVAR_DEF(PrimitivePtr, kPrimSequenceSliceGrad, std::make_shared<Primitive>(kSequenceSliceGrad));
GVAR_DEF(PrimitivePtr, kPrimSequenceMax, std::make_shared<Primitive>(kSequenceMax));
GVAR_DEF(PrimitivePtr, kPrimSequenceMin, std::make_shared<Primitive>(kSequenceMin));
GVAR_DEF(PrimitivePtr, kPrimSequenceAddN, std::make_shared<Primitive>(kSequenceAddN));

// Other miscellaneous
GVAR_DEF(PrimitivePtr, kPrimSampleDistortedBoundingBoxV2, std::make_shared<Primitive>(kSampleDistortedBoundingBoxV2));
@@ -0,0 +1,79 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/sequence_addn.h"

#include <vector>
#include <string>
#include <memory>

#include "utils/check_convert_utils.h"
#include "abstract/abstract_value.h"
#include "abstract/ops/op_infer.h"
#include "abstract/ops/primitive_infer_map.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ops/core_ops.h"
#include "ops/primitive_c.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
AbstractBasePtr SequenceAddNInferInner(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  // Input: a single tuple or list whose elements share the same shape and type.
  constexpr int args_spec_size = 1;
  abstract::CheckArgsSize(op_name, input_args, args_spec_size);
  auto queue = abstract::CheckArg<abstract::AbstractSequence>(op_name, input_args, 0);

  // The value of dynamic_len_element_abs is kAnyValue; there is no need to Broaden it.
  if (queue->dynamic_len()) {
    auto element_abs = queue->dynamic_len_element_abs();
    MS_EXCEPTION_IF_NULL(element_abs);
    return element_abs->Clone();
  }

  if (queue->elements().size() == 0) {
    MS_LOG(EXCEPTION) << "Sequence length should not be 0.";
  }
  // The output has the same shape and type as each element; use the first.
  return queue->elements()[0];
}
}  // namespace
MIND_API_OPERATOR_IMPL(SequenceAddN, BaseOperator);
class SequenceAddNInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceAddNInferInner(primitive, input_args)->BuildShape();
  }

  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceAddNInferInner(prim, input_args)->BuildType();
  }

  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return SequenceAddNInferInner(primitive, input_args);
  }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(SequenceAddN, prim::kPrimSequenceAddN, SequenceAddNInfer, false);
}  // namespace ops
}  // namespace mindspore
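The infer routine above encodes a single rule: the output inherits the shape and type of one element, since AddN is element-wise over the sequence. A rough Python analogue of that rule, with hypothetical attribute names standing in for the AbstractSequence API:

def infer_sequence_addn(seq_abs):
    # Dynamic-length sequence: all elements share one abstract; reuse it.
    if seq_abs.is_dynamic_len:           # hypothetical attribute
        return seq_abs.element_abstract  # hypothetical attribute
    if len(seq_abs.elements) == 0:
        raise ValueError("Sequence length should not be 0.")
    # Static sequence: every element matches; use the first.
    return seq_abs.elements[0]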
@@ -0,0 +1,37 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SEQUENCE_ADDN_H_
#define MINDSPORE_CORE_OPS_SEQUENCE_ADDN_H_

#include "ops/base_operator.h"
#include "mindspore/core/ops/core_ops.h"

namespace mindspore {
namespace ops {
/// \brief Sequence addition operation.
class MIND_API SequenceAddN : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SequenceAddN);
  /// \brief Constructor.
  SequenceAddN() : BaseOperator(prim::kSequenceAddN) {}
  /// \brief Init function.
  void Init() const {}
};
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_SEQUENCE_ADDN_H_
@@ -597,6 +597,31 @@ class SequenceMin(Primitive):
        self.init_prim_io_names(inputs=['sequence'], outputs=['output_data'])


class SequenceAddN(Primitive):
    r"""
    Support sequence AddN operation.

    .. note::
        This is only for internal use.

    Inputs:
        - **sequence** (Union[List, Tuple]) - The sequence.

    Outputs:
        The sum of all elements in the input sequence.

    Raises:
        TypeError: If the 'sequence' is not a list or tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize SequenceAddN"""
        self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])


class tuple_greater_than(Primitive):
    r"""
    Support tuple_greater_than operation 'greater_than(target)'.
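The class above ships without an Examples section; a minimal usage sketch (mirroring the tests that follow, in graph mode with a dynamic-length mutable tuple):

import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import mutable
from mindspore.ops.operations._sequence_ops import SequenceAddN

context.set_context(mode=context.GRAPH_MODE)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = SequenceAddN()

    def construct(self, seq):
        return self.op(seq)

seq = mutable((Tensor(1), Tensor(2), Tensor(3)), True)  # dynamic-length tuple
print(Net()(seq))  # expected: 6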
@@ -0,0 +1,82 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import mutable
from mindspore.ops.operations._sequence_ops import SequenceAddN
from sequence_help import context_prepare

context.set_context(mode=context.GRAPH_MODE)
context_prepare()


class NetSequenceAddN(nn.Cell):
    def __init__(self):
        super().__init__()
        self.op = SequenceAddN()

    def construct(self, seq):
        return self.op(seq)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_tensor_addn():
    """
    Feature: test sequence addn op
    Description: AddN operation on a dynamic-length tuple of scalar tensors
    Expectation: the behavior is matched to python style
    """
    seq = mutable((Tensor(1), Tensor(2), Tensor(3), Tensor(4)), True)
    expect = Tensor(10)
    net = NetSequenceAddN()
    res = net(seq)
    assert np.all(res.asnumpy() == expect.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_tensor_addn1():
    """
    Feature: test sequence addn op
    Description: AddN operation on a dynamic-length tuple of 2x2 tensors
    Expectation: the behavior is matched to python style
    """
    seq = mutable((Tensor([[1, 2], [2, 3]]), Tensor([[2, 3], [3, 4]]), Tensor([[3, 4], [4, 5]])), True)
    expect = Tensor([[6, 9], [9, 12]])
    net = NetSequenceAddN()
    res = net(seq)
    assert np.all(res.asnumpy() == expect.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_seq_addn():
    """
    Feature: test sequence addn op
    Description: AddN operation on a dynamic-length tuple of int scalars
    Expectation: the behavior is matched to python style
    """
    seq = mutable((1, 2, 3, 4, 5, 6), True)
    expect = 21
    net = NetSequenceAddN()
    res = net(seq)
    assert res == expect