diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.cc
index 762a9a1d11a..1d95054bbe3 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.cc
@@ -30,7 +30,7 @@ namespace kernel {
  * }
  */
 std::map<std::string, std::vector<std::pair<std::string, size_t>>> AicpuOpAttrToInputMap = {
-  {prim::kPrimOneHot->name(), {{"depth", 1}}}};
+  {prim::kPrimOneHot->name(), {{"depth", 1}}}, {prim::kPrimConcat->name(), {{"axis", 0}}}};
 
 bool GetAicpuOpAttrToInputInfo(const CNodePtr &kernel_node, std::vector<std::pair<std::string, size_t>> *info) {
   std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.cc
index 217b0dba272..babcb1afef2 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.cc
@@ -97,6 +97,8 @@ bool ConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
 }
 
 std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKernelMod::func_list_ = {
+  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+   &ConcatCpuKernelMod::LaunchKernel<float16>},
   {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    &ConcatCpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    &ConcatCpuKernelMod::LaunchKernel<double>},
@@ -117,6 +119,10 @@ std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKernelMod::func_list_ = {
    &ConcatCpuKernelMod::LaunchKernel<uint32_t>},
   {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
    &ConcatCpuKernelMod::LaunchKernel<uint64_t>},
+  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+   &ConcatCpuKernelMod::LaunchKernel<complex64>},
+  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+   &ConcatCpuKernelMod::LaunchKernel<complex128>},
   {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
    &ConcatCpuKernelMod::LaunchKernel<bool>}};
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.h
index e8a63e0a73a..a9ddbe2d037 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/concat_cpu_kernel.h
@@ -20,12 +20,16 @@
 #include <vector>
 #include <memory>
 #include <utility>
+#include <complex>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
 #include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
+using complex64 = std::complex<float>;
+using complex128 = std::complex<double>;
+
 class ConcatCpuKernelMod : public NativeCpuKernelMod {
  public:
   ConcatCpuKernelMod() = default;
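A quick smoke test for the new float16 and complex CPU registrations above. This is a hypothetical check, not part of the patch; it assumes a MindSpore build that includes these changes.

```python
# Hypothetical smoke test: exercises the new float16 and complex64
# CPU kernel registrations for Concat (assumes this patch is applied).
import numpy as np
from mindspore import Tensor, context, ops

context.set_context(device_target="CPU")
concat = ops.Concat(axis=0)

# float16 inputs, enabled by the new kNumberTypeFloat16 entry.
x1 = Tensor(np.ones((2, 3), dtype=np.float16))
x2 = Tensor(np.zeros((1, 3), dtype=np.float16))
print(concat((x1, x2)).shape)  # (3, 3)

# complex64 inputs, enabled by the new kNumberTypeComplex64 entry.
c1 = Tensor(np.array([[1 + 2j]], dtype=np.complex64))
c2 = Tensor(np.array([[3 - 4j]], dtype=np.complex64))
print(concat((c1, c2)).shape)  # (2, 1)
```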
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 498245246ca..4f9d7ed9bf1 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -120,6 +120,7 @@ constexpr auto kZerosLike = "ZerosLike";
 constexpr auto kOnes = "Ones";
 constexpr auto kOnesLike = "OnesLike";
 constexpr auto kIdentity = "Identity";
+constexpr auto kConcat = "Concat";
 constexpr auto kDiag = "Diag";
 constexpr auto kDiagPart = "DiagPart";
 constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
@@ -265,7 +266,7 @@ GVAR_DEF(PrimitivePtr, kPrimBroadcastShape, std::make_shared<Primitive>("broadcast_shape"));
 GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
 GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
 GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
-GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>("Concat"));
+GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
 GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
 GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
 GVAR_DEF(PrimitivePtr, kPrimTranspose, std::make_shared<Primitive>(kTranspose));
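For context on the `AicpuOpAttrToInputMap` entry in the first file: on the Ascend AICPU backend, Concat's `axis` attribute is materialized as input 0 (the kernel's `concat_dim`), ahead of the dynamic tensor inputs. The sketch below restates how an `(attr_name, input_index)` entry is meant to be consumed; the function and variable names are illustrative, not MindSpore APIs.

```python
# Illustrative sketch only: mirrors the intent of an (attr_name, input_index)
# entry in AicpuOpAttrToInputMap. Names here are hypothetical.
def convert_attrs_to_inputs(op_name, attrs, inputs, attr_to_input_map):
    new_inputs = list(inputs)
    for attr_name, input_idx in attr_to_input_map.get(op_name, []):
        # The attribute value becomes a constant input at input_idx.
        new_inputs.insert(input_idx, attrs[attr_name])
    return new_inputs

attr_to_input_map = {"OneHot": [("depth", 1)], "Concat": [("axis", 0)]}
# Concat(x0, x1) with attr axis=0 becomes Concat(concat_dim, x0, x1):
print(convert_attrs_to_inputs("Concat", {"axis": 0}, ["x0", "x1"], attr_to_input_map))
# [0, 'x0', 'x1']
```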
diff --git a/mindspore/core/ops/concat.cc b/mindspore/core/ops/concat.cc
index c77e08fab64..901d0500c08 100644
--- a/mindspore/core/ops/concat.cc
+++ b/mindspore/core/ops/concat.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,19 +19,106 @@
 #include "ops/concat.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "abstract/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
-MIND_API_BASE_IMPL(Concat, PrimitiveC, BaseOperator);
+namespace {
+abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  const int64_t kOneNum = 1;
+  auto x_shape_ptr = input_args[0]->isa<abstract::AbstractTuple>()
+                       ? input_args[0]->cast<abstract::AbstractTuplePtr>()->BuildShape()
+                       : input_args[0]->cast<abstract::AbstractListPtr>()->BuildShape();
+  auto elements = input_args[0]->isa<abstract::AbstractTuple>()
+                    ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
+                    : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
+  (void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
+                                           prim_name);
+  (void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
+  (void)primitive->AddAttr("inputNums", MakeValue(SizeToLong(elements.size())));
+  auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(element0);
+  auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
+  auto element0_rank = element0_shape.size();
+  auto axis_temp = GetValue<int64_t>(primitive->GetAttr(kAxis));
+  CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis_temp, kIncludeBoth,
+                                              {-SizeToLong(element0_rank), SizeToLong(element0_rank) - kOneNum},
+                                              prim_name);
+  auto axis = axis_temp < 0 ? LongToSize(axis_temp + SizeToLong(element0_rank)) : LongToSize(axis_temp);
+  int64_t all_shp = element0_shape[axis];
+  for (size_t i = 1; i < elements.size(); ++i) {
+    std::string elementi = "element" + std::to_string(i);
+    auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
+    (void)CheckAndConvertUtils::CheckInteger(elementi + " shape rank", SizeToLong(elementi_shape.size()), kEqual,
+                                             SizeToLong(element0_shape.size()), prim_name);
+    for (size_t j = 0; j < element0_rank; ++j) {
+      if (j != axis && elementi_shape[j] != element0_shape[j]) {
+        MS_LOG(EXCEPTION) << "For '" << prim_name << "', the shape of element " << i
+                          << " must match the shape of element 0 on every axis except the concat axis.";
+      }
+    }
+    all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis];
+  }
+  auto ret_shape = element0_shape;
+  ret_shape[axis] = all_shp;
+  if (x_shape_ptr->IsDynamic()) {
+    auto element0_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
+    auto element0_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
+    auto ret_max_shape = element0_max_shape;
+    auto ret_min_shape = element0_min_shape;
+    for (size_t i = 1; i < elements.size(); ++i) {
+      // Accumulate each element's own min/max shape, not element0's again.
+      auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMaxShape];
+      auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMinShape];
+      ret_max_shape[axis] += elementi_max_shape[axis];
+      ret_min_shape[axis] += elementi_min_shape[axis];
+    }
+    return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
+  } else {
+    return std::make_shared<abstract::Shape>(ret_shape);
+  }
+}
+
+TypePtr ConcatInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
+    MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input must be a list or tuple of tensors.";
+  }
+  auto elements = input_args[0]->isa<abstract::AbstractTuple>()
+                    ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
+                    : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
+  std::map<std::string, TypePtr> types;
+  for (size_t i = 0; i < elements.size(); ++i) {
+    std::string elementi = "element" + std::to_string(i);
+    (void)types.emplace(elementi, elements[i]->BuildType());
+  }
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim_name);
+  return elements[0]->BuildType();
+}
+}  // namespace
+
 void Concat::Init(const int64_t axis) { this->set_axis(axis); }
 int64_t Concat::get_axis() const {
   auto value_ptr = this->GetAttr(kAxis);
   return GetValue<int64_t>(value_ptr);
 }
-
 void Concat::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
-REGISTER_PRIMITIVE_C(kNameConcat, Concat);
+MIND_API_BASE_IMPL(Concat, PrimitiveC, BaseOperator);
+
+AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                            const std::vector<AbstractBasePtr> &input_args) {
+  const int64_t kInputNum = 1;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto infer_type = ConcatInferType(primitive, input_args);
+  auto infer_shape = ConcatInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
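The new C++ `ConcatInferShape` above encodes a simple rule: all inputs must have the same rank, non-axis dimensions must match, and the output's axis dimension is the sum of the inputs' axis dimensions, with -1 (unknown) propagating. A plain-Python restatement for review, not code from the patch:

```python
# Plain-Python sketch of the shape rule ConcatInferShape implements.
def concat_infer_shape(shapes, axis):
    rank = len(shapes[0])
    if not -rank <= axis <= rank - 1:
        raise ValueError(f"axis {axis} out of range [{-rank}, {rank - 1}]")
    axis = axis % rank  # normalize a negative axis
    out = list(shapes[0])
    for i, shp in enumerate(shapes[1:], start=1):
        if len(shp) != rank:
            raise ValueError(f"element {i} rank {len(shp)} != {rank}")
        for j in range(rank):
            if j != axis and shp[j] != out[j]:
                raise RuntimeError(f"element {i} differs from element 0 on axis {j}")
        # -1 (dynamic) is contagious along the concat axis.
        out[axis] = -1 if -1 in (out[axis], shp[axis]) else out[axis] + shp[axis]
    return tuple(out)

assert concat_infer_shape([(2, 3), (1, 3)], 0) == (3, 3)
assert concat_infer_shape([(2, -1), (4, -1)], 0) == (6, -1)
```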
   void Init(const int64_t axis = 0);
   /// \brief Set axis.
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 2f0c7bfd62a..2c055696f14 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -116,6 +116,7 @@ from .lower_bound import _lower_bound_aicpu
 from .upper_bound import _upper_bound_aicpu
 from .zeros_like import _zeros_like_aicpu
 from .ones_like import _ones_like_aicpu
+from .concat import _concat_aicpu
 from .grid_sampler_3d import _grid_sampler_3d_aicpu
 from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu
 from .environ_create import _environ_create_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/concat.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/concat.py
new file mode 100644
index 00000000000..acbd13f56e0
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/concat.py
@@ -0,0 +1,57 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Concat op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+concat_op_info = AiCPURegOp("Concat") \
+    .fusion_type("OPAQUE") \
+    .input(0, "concat_dim", "required") \
+    .input(1, "x", "dynamic") \
+    .output(0, "y", "required") \
+    .attr("N", "int") \
+    .dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.C128_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.C128_Default) \
+    .get_op_info()
+
+
+@op_info_register(concat_op_info)
+def _concat_aicpu():
+    """Concat AiCPU register"""
+    return
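The array_ops.py change below deletes the Python `__infer__` and relies on the C++ `ConcatInfer` registered via `REGISTER_PRIMITIVE_EVAL_IMPL` above, so user-visible behavior should be unchanged. A quick check, assuming a build with this patch:

```python
# Behavior check: shape/type inference now comes from the C++ ConcatInfer.
import numpy as np
from mindspore import Tensor, ops

concat = ops.Concat(axis=1)
a = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
b = Tensor(np.ones((2, 2), dtype=np.float32))
out = concat((a, b))
print(out.shape)  # (2, 5)
print(out.dtype)  # Float32

# Mismatched dtypes are rejected by the C++ type check:
# concat((a, Tensor(np.ones((2, 2), dtype=np.float64))))  # raises TypeError
```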
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index 542351c5c45..4b62bfdd84c 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -28,7 +28,6 @@ from mindspore import context
 from mindspore.common.initializer import Zero
 from .. import signature as sig
 from .._utils import get_broadcast_shape, is_shape_unknown
-from .._utils import get_concat_offset
 from ..operations.math_ops import _infer_shape_reduce
 from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 from ..._checkparam import Rel
@@ -2567,7 +2566,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
         return out
 
 
-class Concat(PrimitiveWithInfer):
+class Concat(Primitive):
     r"""
     Connect tensor in the specified axis.
 
@@ -2601,6 +2600,10 @@ class Concat(PrimitiveWithInfer):
 
     Raises:
         TypeError: If `axis` is not an int.
+        TypeError: If tensors in `input_x` have different data types.
+        ValueError: If tensors in `input_x` have different ranks.
+        ValueError: If `axis` is not in the range [-dims, dims - 1].
+        RuntimeError: If shapes of tensors in `input_x` differ on any axis other than `axis`.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -2627,35 +2630,6 @@ class Concat(PrimitiveWithInfer):
         """Initialize Concat"""
         validator.check_value_type("axis", axis, [int], self.name)
 
-    def __infer__(self, input_x):
-        axis = self.axis
-        x_shp = input_x['shape']
-        x_type = input_x['dtype']
-        _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
-        self.add_prim_attr('inputNums', len(x_shp))
-        ret_shp = x_shp[0].copy()
-        value = None
-        if input_x['value'] is not None:
-            value = Tensor(np.concatenate([x.asnumpy() for x in input_x['value']], axis=axis))
-        ret_shp[axis] = all_shp
-        out = {'shape': ret_shp,
-               'dtype': x_type[0],
-               'value': value}
-        if -1 in x_shp[0]:
-            x_min_shp = input_x['min_shape']
-            ret_min_shp = x_min_shp[0].copy()
-            ret_min_shp[axis] = 0
-            for all_min_shp in x_min_shp:
-                ret_min_shp[axis] += all_min_shp[axis]
-            out['min_shape'] = ret_min_shp
-            x_max_shp = input_x['max_shape']
-            ret_max_shp = x_max_shp[0].copy()
-            ret_max_shp[axis] = 0
-            for all_max_shp in x_max_shp:
-                ret_max_shp[axis] += all_max_shp[axis]
-            out['max_shape'] = ret_max_shp
-        return out
-
 
 class ParallelConcat(PrimitiveWithInfer):
     r"""