!26772 [assistant] [ops] Add new array operator Concat

Merge pull request !26772 from TR-nbu/Concat
i-robot 2022-03-28 01:40:03 +00:00 committed by Gitee
commit 789180e429
9 changed files with 169 additions and 39 deletions

View File

@@ -30,7 +30,7 @@ namespace kernel {
* }
*/
std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap = {
{prim::kPrimOneHot->name(), {{"depth", 1}}}};
{prim::kPrimOneHot->name(), {{"depth", 1}}}, {prim::kPrimConcat->name(), {{"axis", 0}}}};
bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *info) {
std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
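For orientation, the new map entry's effect restated as a hedged Python sketch (all names hypothetical): each entry tells the AiCPU backend which attributes of an op must be materialized as kernel inputs, and at which input index — for Concat, `axis` becomes input 0.

# Hypothetical Python mirror of AicpuOpAttrToInputMap:
# op name -> list of (attr_name, input_index) conversions.
aicpu_attr_to_input = {
    "OneHot": [("depth", 1)],
    "Concat": [("axis", 0)],  # entry added by this PR
}

def get_aicpu_attr_to_input_info(op_name):
    """Return the attr->input conversions required for op_name, if any."""
    return aicpu_attr_to_input.get(op_name, [])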

View File

@@ -97,6 +97,8 @@ bool ConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
}
std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKernelMod::func_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ConcatCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ConcatCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
@@ -117,6 +119,10 @@ std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKern
&ConcatCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&ConcatCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&ConcatCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&ConcatCpuKernelMod::LaunchKernel<complex128>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&ConcatCpuKernelMod::LaunchKernel<bool>}};
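Complex support is the visible addition to this dispatch table. A minimal sketch of what it enables, assuming a CPU build of this branch and the 1.x `ops.Concat` call style:

import numpy as np
from mindspore import context, ops, Tensor

context.set_context(device_target="CPU")
a = Tensor(np.array([1 + 2j, 3 + 4j], dtype=np.complex64))
b = Tensor(np.array([5 + 6j], dtype=np.complex64))
# complex64 inputs now dispatch to LaunchKernel<complex64> registered above.
print(ops.Concat(axis=0)((a, b)))  # expected: [1+2j 3+4j 5+6j]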

View File

@@ -20,12 +20,16 @@
#include <vector>
#include <memory>
#include <utility>
#include <complex>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
class ConcatCpuKernelMod : public NativeCpuKernelMod {
public:
ConcatCpuKernelMod() = default;

View File

@@ -120,6 +120,7 @@ constexpr auto kZerosLike = "ZerosLike";
constexpr auto kOnes = "Ones";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kIdentity = "Identity";
constexpr auto kConcat = "Concat";
constexpr auto kDiag = "Diag";
constexpr auto kDiagPart = "DiagPart";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
@@ -265,7 +266,7 @@ GVAR_DEF(PrimitivePtr, kPrimBroadcastShape, std::make_shared<Primitive>("broadca
GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>("Concat"));
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
GVAR_DEF(PrimitivePtr, kPrimTranspose, std::make_shared<Primitive>(kTranspose));

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,19 +19,106 @@
#include "ops/concat.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_BASE_IMPL(Concat, PrimitiveC, BaseOperator);
namespace {
abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t kOneNum = 1;
auto x_shape_ptr = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->BuildShape()
: input_args[0]->cast<abstract::AbstractListPtr>()->BuildShape();
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
prim_name);
(void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
(void)primitive->AddAttr("inputNums", MakeValue(SizeToLong(elements.size())));
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0);
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
auto element0_rank = element0_shape.size();
auto axis_temp = GetValue<int64_t>(primitive->GetAttr(kAxis));
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis_temp, kIncludeBoth,
{-SizeToLong(element0_rank), SizeToLong(element0_rank) - kOneNum},
prim_name);
auto axis = axis_temp < 0 ? LongToSize(axis_temp + element0_rank) : LongToSize(axis_temp);
int64_t all_shp = element0_shape[axis];
for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
SizeToLong(element0_shape.size()), prim_name);
for (size_t j = 0; j < element0_rank; ++j) {
if (j != axis && elementi_shape[j] != element0_shape[j]) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', the shape of input element " << i
<< " must match the shape of element 0 in every dimension except the concat axis.";
}
}
all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis];
}
auto ret_shape = element0_shape;
ret_shape[axis] = all_shp;
if (x_shape_ptr->IsDynamic()) {
auto element0_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
auto element0_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
auto ret_max_shape = element0_max_shape;
auto ret_min_shape = element0_min_shape;
for (size_t i = 1; i < elements.size(); ++i) {
auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMaxShape];
auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMinShape];
ret_max_shape[axis] += elementi_max_shape[axis];
ret_min_shape[axis] += elementi_min_shape[axis];
}
return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
} else {
return std::make_shared<abstract::Shape>(ret_shape);
}
}
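The static-shape branch of ConcatInferShape, restated as a hedged Python sketch (helper name hypothetical): ranks must match, non-axis dimensions must agree, and the output's axis dimension is the sum, with -1 propagating to mark a dynamic size.

def concat_infer_shape(shapes, axis):
    """Hypothetical mirror of ConcatInferShape's static-shape branch."""
    rank = len(shapes[0])
    if not -rank <= axis <= rank - 1:
        raise ValueError(f"Concat axis {axis} out of range for rank {rank}")
    axis = axis + rank if axis < 0 else axis
    out = list(shapes[0])
    for i, shape in enumerate(shapes[1:], start=1):
        if len(shape) != rank:
            raise ValueError(f"element{i} rank {len(shape)} != {rank}")
        for j in range(rank):
            if j != axis and shape[j] != out[j]:
                raise ValueError(f"element{i} differs from element0 at dim {j}")
        # -1 (dynamic) in any input's axis dim makes the output dim dynamic.
        out[axis] = -1 if -1 in (out[axis], shape[axis]) else out[axis] + shape[axis]
    return out

assert concat_infer_shape([(2, 3), (4, 3)], 0) == [6, 3]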
TypePtr ConcatInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
MS_EXCEPTION(TypeError) << "The input of Concat must be a list or a tuple of tensors.";
}
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
std::map<std::string, TypePtr> types;
for (size_t i = 0; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
(void)types.emplace(elementi, elements[i]->BuildType());
}
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim_name);
return elements[0]->BuildType();
}
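ConcatInferType's contract, likewise as a hedged sketch: every element must share one dtype drawn from the valid set (numeric, complex, bool), and the output takes element 0's dtype.

def concat_infer_type(dtypes, valid_types):
    """Hypothetical mirror of ConcatInferType: one shared dtype, from valid_types."""
    if any(dt not in valid_types for dt in dtypes) or len(set(dtypes)) != 1:
        raise TypeError(f"Concat inputs must share one dtype from {valid_types}, got {dtypes}")
    return dtypes[0]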
} // namespace
void Concat::Init(const int64_t axis) { this->set_axis(axis); }
int64_t Concat::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
void Concat::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
MIND_API_BASE_IMPL(Concat, PrimitiveC, BaseOperator);
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = ConcatInferType(primitive, input_args);
auto infer_shape = ConcatInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ class MIND_API Concat : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Concat);
/// \brief Constructor.
Concat() : BaseOperator(kNameConcat) {}
Concat() : BaseOperator(kNameConcat) { InitIOName({"x"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Concat for the inputs.
void Init(const int64_t axis = 0);
/// \brief Set axis.

View File

@@ -116,6 +116,7 @@ from .lower_bound import _lower_bound_aicpu
from .upper_bound import _upper_bound_aicpu
from .zeros_like import _zeros_like_aicpu
from .ones_like import _ones_like_aicpu
from .concat import _concat_aicpu
from .grid_sampler_3d import _grid_sampler_3d_aicpu
from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu
from .environ_create import _environ_create_aicpu

View File

@@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Concat op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
concat_op_info = AiCPURegOp("Concat") \
.fusion_type("OPAQUE") \
.input(0, "concat_dim", "required") \
.input(1, "x", "dynamic") \
.output(0, "y", "required") \
.attr("N", "int") \
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.C128_Default) \
.dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(concat_op_info)
def _concat_aicpu():
"""Concat AiCPU register"""
return
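Reading the registration: `concat_dim` (the axis, int32 or int64) rides as input 0, the tensors arrive as the dynamic input `x`, and `y` mirrors the dtype of `x`; the `{"axis", 0}` map entry from the first file is what moves the attribute into that input slot. A hedged smoke test, assuming an Ascend device where this AiCPU kernel is selected:

import numpy as np
from mindspore import context, ops, Tensor

context.set_context(device_target="Ascend")  # assumption: AiCPU Concat is chosen
x = (Tensor(np.ones((2, 3), np.uint16)), Tensor(np.zeros((2, 3), np.uint16)))
y = ops.Concat(axis=1)(x)  # uint16 is covered by the U16_Default rows above
print(y.shape, y.dtype)  # expected: (2, 6) UInt16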

View File

@@ -28,7 +28,6 @@ from mindspore import context
from mindspore.common.initializer import Zero
from .. import signature as sig
from .._utils import get_broadcast_shape, is_shape_unknown
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
from ..._checkparam import Rel
@@ -2567,7 +2566,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
return out
class Concat(PrimitiveWithInfer):
class Concat(Primitive):
r"""
Connects input tensors along the specified axis.
@@ -2601,6 +2600,10 @@ class Concat(PrimitiveWithInfer):
Raises:
TypeError: If `axis` is not an int.
TypeError: If tensors in `input_x` have different dtypes.
ValueError: If tensors in `input_x` have different ranks.
ValueError: If `axis` is not in the range [-dims, dims - 1].
RuntimeError: If shapes of tensors in `input_x` differ in any dimension other than `axis`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -2627,35 +2630,6 @@
"""Initialize Concat"""
validator.check_value_type("axis", axis, [int], self.name)
def __infer__(self, input_x):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy()
value = None
if input_x['value'] is not None:
value = Tensor(np.concatenate([x.asnumpy() for x in input_x['value']], axis=axis))
ret_shp[axis] = all_shp
out = {'shape': ret_shp,
'dtype': x_type[0],
'value': value}
if -1 in x_shp[0]:
x_min_shp = input_x['min_shape']
ret_min_shp = x_min_shp[0].copy()
ret_min_shp[axis] = 0
for all_min_shp in x_min_shp:
ret_min_shp[axis] += all_min_shp[axis]
out['min_shape'] = ret_min_shp
x_max_shp = input_x['max_shape']
ret_max_shp = x_max_shp[0].copy()
ret_max_shp[axis] = 0
for all_max_shp in x_max_shp:
ret_max_shp[axis] += all_max_shp[axis]
out['max_shape'] = ret_max_shp
return out
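The `__infer__` removed above duplicated, in Python, what the new C++ `ConcatInfer` (registered via `REGISTER_PRIMITIVE_EVAL_IMPL` earlier in this PR) now provides, so switching the base class to `Primitive` drops the frontend path. A quick check under the assumption that user-visible behavior is unchanged:

import numpy as np
from mindspore import ops, Tensor

a = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
b = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
out = ops.Concat(axis=0)((a, b))
print(out.shape)  # expected: (4, 3), same as the removed Python infer produced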
class ParallelConcat(PrimitiveWithInfer):
r"""