forked from mindspore-Ecosystem/mindspore
!26772 [assistant] [ops] Add new array operator Concat
Merge pull request !26772 from TR-nbu/Concat
This commit is contained in:
commit
789180e429
|
@ -30,7 +30,7 @@ namespace kernel {
|
|||
* }
|
||||
*/
|
||||
std::map<string, std::vector<std::pair<string, size_t>>> AicpuOpAttrToInputMap = {
|
||||
{prim::kPrimOneHot->name(), {{"depth", 1}}}};
|
||||
{prim::kPrimOneHot->name(), {{"depth", 1}}}, {prim::kPrimConcat->name(), {{"axis", 0}}}};
|
||||
|
||||
bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<string, size_t>> *info) {
|
||||
std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
|
|
@ -97,6 +97,8 @@ bool ConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
|
|||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&ConcatCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&ConcatCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
|
@ -117,6 +119,10 @@ std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKern
|
|||
&ConcatCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
&ConcatCpuKernelMod::LaunchKernel<uint64_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
&ConcatCpuKernelMod::LaunchKernel<complex64>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
&ConcatCpuKernelMod::LaunchKernel<complex128>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
&ConcatCpuKernelMod::LaunchKernel<bool>}};
|
||||
|
||||
|
|
|
@ -20,12 +20,16 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <complex>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
|
||||
class ConcatCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ConcatCpuKernelMod() = default;
|
||||
|
|
|
@ -120,6 +120,7 @@ constexpr auto kZerosLike = "ZerosLike";
|
|||
constexpr auto kOnes = "Ones";
|
||||
constexpr auto kOnesLike = "OnesLike";
|
||||
constexpr auto kIdentity = "Identity";
|
||||
constexpr auto kConcat = "Concat";
|
||||
constexpr auto kDiag = "Diag";
|
||||
constexpr auto kDiagPart = "DiagPart";
|
||||
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
|
||||
|
@ -265,7 +266,7 @@ GVAR_DEF(PrimitivePtr, kPrimBroadcastShape, std::make_shared<Primitive>("broadca
|
|||
GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>("Concat"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTranspose, std::make_shared<Primitive>(kTranspose));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,19 +19,106 @@
|
|||
#include "ops/concat.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_BASE_IMPL(Concat, PrimitiveC, BaseOperator);
|
||||
namespace {
|
||||
abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t kOneNum = 1;
|
||||
auto x_shape_ptr = input_args[0]->isa<abstract::AbstractTuple>()
|
||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->BuildShape()
|
||||
: input_args[0]->cast<abstract::AbstractListPtr>()->BuildShape();
|
||||
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
|
||||
prim_name);
|
||||
(void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
|
||||
(void)primitive->AddAttr("inputNums", MakeValue(SizeToLong(elements.size())));
|
||||
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(element0);
|
||||
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
|
||||
auto element0_rank = element0_shape.size();
|
||||
auto axis_temp = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis_temp, kIncludeBoth,
|
||||
{-SizeToLong(element0_rank), SizeToLong(element0_rank) - kOneNum},
|
||||
prim_name);
|
||||
auto axis = axis_temp < 0 ? LongToSize(axis_temp + element0_rank) : LongToSize(axis_temp);
|
||||
int64_t all_shp = element0_shape[axis];
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
|
||||
SizeToLong(element0_shape.size()), prim_name);
|
||||
for (size_t j = 0; j < element0_rank; ++j) {
|
||||
if (j != axis && elementi_shape[j] != element0_shape[j]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << prim_name << "', element " << i
|
||||
<< " shape in input should concat with first element, but it can not.";
|
||||
}
|
||||
}
|
||||
all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis];
|
||||
}
|
||||
auto ret_shape = element0_shape;
|
||||
ret_shape[axis] = all_shp;
|
||||
if (x_shape_ptr->IsDynamic()) {
|
||||
auto element0_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
|
||||
auto element0_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
|
||||
auto ret_max_shape = element0_max_shape;
|
||||
auto ret_min_shape = element0_min_shape;
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
|
||||
auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
|
||||
ret_max_shape[axis] += elementi_max_shape[axis];
|
||||
ret_min_shape[axis] += elementi_min_shape[axis];
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
|
||||
} else {
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr ConcatInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
|
||||
MS_EXCEPTION(TypeError) << "The input of Concat must be list or tuple of tensors.";
|
||||
}
|
||||
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||
std::map<std::string, TypePtr> types;
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
(void)types.emplace(elementi, elements[i]->BuildType());
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim_name);
|
||||
return elements[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void Concat::Init(const int64_t axis) { this->set_axis(axis); }
|
||||
int64_t Concat::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Concat::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
|
||||
MIND_API_BASE_IMPL(Concat, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
const int64_t kInputNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = ConcatInferType(primitive, input_args);
|
||||
auto infer_shape = ConcatInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,7 +31,7 @@ class MIND_API Concat : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(Concat);
|
||||
/// \brief Constructor.
|
||||
Concat() : BaseOperator(kNameConcat) {}
|
||||
Concat() : BaseOperator(kNameConcat) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Concat for the inputs.
|
||||
void Init(const int64_t axis = 0);
|
||||
/// \brief Set axis.
|
||||
|
|
|
@ -116,6 +116,7 @@ from .lower_bound import _lower_bound_aicpu
|
|||
from .upper_bound import _upper_bound_aicpu
|
||||
from .zeros_like import _zeros_like_aicpu
|
||||
from .ones_like import _ones_like_aicpu
|
||||
from .concat import _concat_aicpu
|
||||
from .grid_sampler_3d import _grid_sampler_3d_aicpu
|
||||
from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu
|
||||
from .environ_create import _environ_create_aicpu
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Concat op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
concat_op_info = AiCPURegOp("Concat") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "concat_dim", "required") \
|
||||
.input(1, "x", "dynamic") \
|
||||
.output(0, "y", "required") \
|
||||
.attr("N", "int") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(concat_op_info)
|
||||
def _concat_aicpu():
|
||||
"""Concat AiCPU register"""
|
||||
return
|
|
@ -28,7 +28,6 @@ from mindspore import context
|
|||
from mindspore.common.initializer import Zero
|
||||
from .. import signature as sig
|
||||
from .._utils import get_broadcast_shape, is_shape_unknown
|
||||
from .._utils import get_concat_offset
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||
from ..._checkparam import Rel
|
||||
|
@ -2567,7 +2566,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class Concat(PrimitiveWithInfer):
|
||||
class Concat(Primitive):
|
||||
r"""
|
||||
Connect tensor in the specified axis.
|
||||
|
||||
|
@ -2601,6 +2600,10 @@ class Concat(PrimitiveWithInfer):
|
|||
|
||||
Raises:
|
||||
TypeError: If `axis` is not an int.
|
||||
TypeError: If `input_x` have different type of tensor.
|
||||
ValueError: If `input_x` have different dimension of tensor.
|
||||
ValueError: If `axis` not in [-dims, dims - 1].
|
||||
RuntimeError: If tensor's shape in `input_x` except for `axis` are different.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -2627,35 +2630,6 @@ class Concat(PrimitiveWithInfer):
|
|||
"""Initialize Concat"""
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
|
||||
def __infer__(self, input_x):
|
||||
axis = self.axis
|
||||
x_shp = input_x['shape']
|
||||
x_type = input_x['dtype']
|
||||
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
|
||||
self.add_prim_attr('inputNums', len(x_shp))
|
||||
ret_shp = x_shp[0].copy()
|
||||
value = None
|
||||
if input_x['value'] is not None:
|
||||
value = Tensor(np.concatenate([x.asnumpy() for x in input_x['value']], axis=axis))
|
||||
ret_shp[axis] = all_shp
|
||||
out = {'shape': ret_shp,
|
||||
'dtype': x_type[0],
|
||||
'value': value}
|
||||
if -1 in x_shp[0]:
|
||||
x_min_shp = input_x['min_shape']
|
||||
ret_min_shp = x_min_shp[0].copy()
|
||||
ret_min_shp[axis] = 0
|
||||
for all_min_shp in x_min_shp:
|
||||
ret_min_shp[axis] += all_min_shp[axis]
|
||||
out['min_shape'] = ret_min_shp
|
||||
x_max_shp = input_x['max_shape']
|
||||
ret_max_shp = x_max_shp[0].copy()
|
||||
ret_max_shp[axis] = 0
|
||||
for all_max_shp in x_max_shp:
|
||||
ret_max_shp[axis] += all_max_shp[axis]
|
||||
out['max_shape'] = ret_max_shp
|
||||
return out
|
||||
|
||||
|
||||
class ParallelConcat(PrimitiveWithInfer):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue