From 9a0a0ad5e857e8914490f836c301dd938aa4fa78 Mon Sep 17 00:00:00 2001
From: He Wei
Date: Mon, 11 Apr 2022 20:15:04 +0800
Subject: [PATCH] Add inner operator FlattenConcat

---
 mindspore/core/abstract/ops/infer_functions.h |  4 +-
 mindspore/core/abstract/ops/prim_arrays.cc    | 55 ++++++++++++++++++-
 .../core/abstract/ops/primitive_infer_map.cc  |  1 +
 mindspore/core/ir/tensor.cc                   | 13 +----
 mindspore/core/ops/core_ops.h                 |  2 +
 mindspore/core/utils/shape_utils.h            | 16 +++++-
 .../mindspore/ops/operations/_inner_ops.py    | 34 +++++++++++-
 tests/ut/python/ops/test_array_ops.py         | 16 +++++-
 8 files changed, 122 insertions(+), 19 deletions(-)

diff --git a/mindspore/core/abstract/ops/infer_functions.h b/mindspore/core/abstract/ops/infer_functions.h
index e4a644b6931..704f766d1dd 100644
--- a/mindspore/core/abstract/ops/infer_functions.h
+++ b/mindspore/core/abstract/ops/infer_functions.h
@@ -1,7 +1,7 @@
 /**
  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  *
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -282,6 +282,8 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplConcatOffset(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/ops/prim_arrays.cc b/mindspore/core/abstract/ops/prim_arrays.cc
index 709324f8741..6c62c7f6d6b 100644
--- a/mindspore/core/abstract/ops/prim_arrays.cc
+++ b/mindspore/core/abstract/ops/prim_arrays.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,13 +18,14 @@
 #include
 #include
 #include
+#include "abstract/abstract_value.h"
 #include "abstract/ops/infer_functions.h"
-#include "abstract/utils.h"
 #include "abstract/param_validator.h"
-#include "utils/shape_utils.h"
+#include "abstract/utils.h"
 #include "ops/op_utils.h"
 #include "utils/anf_utils.h"
 #include "utils/check_convert_utils.h"
+#include "utils/shape_utils.h"
 
 namespace mindspore {
 namespace abstract {
@@ -1126,6 +1127,54 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
   return ret;
 }
 
+AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  auto seq = dyn_cast<AbstractSequence>(args_spec_list[0]);
+  if (seq == nullptr) {
+    MS_LOG(EXCEPTION) << "The input for '" << primitive->name() << "' should be tuple or list, but got "
+                      << args_spec_list[0]->type_name();
+  }
+  // Group inputs by data type and calculate their chunk sizes.
+  std::map<TypeId, size_t> chunks;
+  for (auto &element : seq->elements()) {
+    auto abs_tensor = dyn_cast<AbstractTensor>(element);
+    if (abs_tensor == nullptr) {
+      MS_LOG(EXCEPTION) << "The input element for '" << primitive->name() << "' should be Tensor, but got "
+                        << element->type_name();
+    }
+    // Calculate data size (number of elements) by shape.
+    auto base_shape = abs_tensor->BuildShape();
+    MS_EXCEPTION_IF_NULL(base_shape);
+    auto shape = base_shape->cast<ShapePtr>();
+    if (shape == nullptr) {
+      MS_LOG(EXCEPTION) << "The input tensors for '" << primitive->name() << "' should have shape, but got "
+                        << base_shape->ToString();
+    }
+    auto data_size = SizeOf(shape->shape());
+    if (data_size == 0) {
+      MS_LOG(EXCEPTION) << "The input tensors for '" << primitive->name() << "' should have static shape, but got "
+                        << shape->ToString();
+    }
+    // Find data type from the AbstractTensor.
+    const auto &element_abs = abs_tensor->element();
+    MS_EXCEPTION_IF_NULL(element_abs);
+    auto dtype = element_abs->BuildType();
+    MS_EXCEPTION_IF_NULL(dtype);
+    // Group them by data type.
+    chunks[dtype->type_id()] += data_size;
+  }
+  // Make result AbstractTuple.
+  AbstractBasePtrList tuple_element;
+  tuple_element.reserve(chunks.size());
+  for (auto &chunk : chunks) {
+    ShapeVector shape_vec{static_cast<int64_t>(chunk.second)};
+    auto abs = std::make_shared<AbstractTensor>(TypeIdToType(chunk.first), shape_vec);
+    (void)tuple_element.emplace_back(abs);
+  }
+  return std::make_shared<AbstractTuple>(std::move(tuple_element));
+}
+
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
   const std::string &op_name = primitive->name();
diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc
index 34db6523065..69e225eb4bd 100644
--- a/mindspore/core/abstract/ops/primitive_infer_map.cc
+++ b/mindspore/core/abstract/ops/primitive_infer_map.cc
@@ -182,6 +182,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMaskedSelect, R{InferImplMaskedSelect, nullptr, true}},
     {prim::kPrimTensorCopySlices, R{InferImplTensorCopySlices, nullptr, true}},
     {prim::kPrimNonZero, R{InferImplNonZero, nullptr, true}},
+    {prim::kPrimFlattenConcat, R{InferImplFlattenConcat, nullptr, true}},
     // Structure
     {prim::kPrimMakeTuple, R{InferImplMakeTuple, nullptr, true}},
     {prim::kPrimMakeList, R{InferImplMakeList, nullptr, true}},
diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc
index e5e76c3b16d..0a18e118584 100644
--- a/mindspore/core/ir/tensor.cc
+++ b/mindspore/core/ir/tensor.cc
@@ -30,6 +30,7 @@
 #include "base/complex_storage.h"
 #include "utils/log_adapter.h"
 #include "utils/ms_utils_secure.h"
+#include "utils/shape_utils.h"
 
 namespace mindspore {
 namespace tensor {
@@ -50,18 +51,6 @@ static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
   return data_type ? data_type->type_id() : defaultTypeId;
 }
 
-static size_t SizeOf(const ShapeVector &shape) {
-  int64_t data_size = 1;
-  for (auto dim : shape) {
-    if (dim < 0) {
-      // For dynamic shape which has negative dimensions, data size should be zero.
-      return 0;
-    }
-    data_size *= dim;
-  }
-  return static_cast<size_t>(data_size);
-}
-
 static std::string ShapeToString(const ShapeVector &shape) {
   std::string str = "[";
   const size_t count = shape.size();
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index b00f7e62148..640d645deb4 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -122,6 +122,7 @@ constexpr auto kOnes = "Ones";
 constexpr auto kOnesLike = "OnesLike";
 constexpr auto kIdentity = "Identity";
 constexpr auto kConcat = "Concat";
+constexpr auto kFlattenConcat = "FlattenConcat";
 constexpr auto kRightShift = "RightShift";
 constexpr auto kDiag = "Diag";
 constexpr auto kDiagPart = "DiagPart";
@@ -279,6 +280,7 @@ GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
 GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
 GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
 GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
+GVAR_DEF(PrimitivePtr, kPrimFlattenConcat, std::make_shared<Primitive>(kFlattenConcat));
 GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
 GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
 GVAR_DEF(PrimitivePtr, kPrimTranspose, std::make_shared<Primitive>(kTranspose));
diff --git a/mindspore/core/utils/shape_utils.h b/mindspore/core/utils/shape_utils.h
index ca0fa5412e3..125849a2416 100644
--- a/mindspore/core/utils/shape_utils.h
+++ b/mindspore/core/utils/shape_utils.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,4 +19,18 @@
 
 #include "mindapi/base/shape_vector.h"
 
+namespace mindspore {
+inline size_t SizeOf(const ShapeVector &shape) {
+  int64_t data_size = 1;
+  for (auto dim : shape) {
+    if (dim < 0) {
+      // For dynamic shape which has negative dimensions, data size should be zero.
+      return 0;
+    }
+    data_size *= dim;
+  }
+  return static_cast<size_t>(data_size);
+}
+}  // namespace mindspore
+
 #endif  // MINDSPORE_SHAPE_UTILS_INFO_H_
diff --git a/mindspore/python/mindspore/ops/operations/_inner_ops.py b/mindspore/python/mindspore/ops/operations/_inner_ops.py
index 52c3596a10f..058b8de4225 100755
--- a/mindspore/python/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_inner_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1911,3 +1911,35 @@ class Format(PrimitiveWithInfer):
             var_value.append(item["value"])
         value = str_value.format(*var_value)
         return {'dtype': mstype.string, 'shape': [], 'value': value}
+
+
+class FlattenConcat(Primitive):
+    """
+    Flatten input tensors and concatenate them into several chunk tensors grouped by data types.
+
+    Inputs:
+        - **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.
+
+    Outputs:
+        tuple[Tensor], the resulting chunk tensors, one chunk for each distinct input data type.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> from mindspore.ops.operations import _inner_ops as inner
+        >>> t1 = Tensor(np.array([1]).astype(np.float32))
+        >>> t2 = Tensor(np.array([2]).astype(np.float32))
+        >>> t3 = Tensor(np.array([3]).astype(np.float64))
+        >>> t4 = Tensor(np.array([4]).astype(np.float32))
+        >>> t5 = Tensor(np.array([5]).astype(np.float64))
+        >>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
+        >>> print(chunks[0].asnumpy())
+        [1. 2. 4.]
+        >>> print(chunks[1].asnumpy())
+        [3. 5.]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize FlattenConcat"""
diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py
index 9fa0c7d3a98..2cef13351e9 100644
--- a/tests/ut/python/ops/test_array_ops.py
+++ b/tests/ut/python/ops/test_array_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -359,6 +359,15 @@ class RangeNet(Cell):
         return self.range_ops(x)
 
 
+class NetForFlattenConcat(Cell):
+    def __init__(self):
+        super(NetForFlattenConcat, self).__init__()
+        self.flatten_concat = inner.FlattenConcat()
+
+    def construct(self, x1, x2, x3):
+        return self.flatten_concat([x1, x2, x3])
+
+
 test_case_array_ops = [
     ('CustNet1', {
         'block': CustNet1(),
@@ -408,6 +417,11 @@ test_case_array_ops = [
     ('RangeNet', {
         'block': RangeNet(),
         'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}),
+    ('FlattenConcat', {
+        'block': NetForFlattenConcat(),
+        'desc_inputs': [Tensor(np.array([1], np.float32)),
+                        Tensor(np.array([2], np.float32)),
+                        Tensor(np.array([3], np.float64))]}),
     ('TensorShapeNet', {'block': TensorShapeNet(),
                         'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]})
 ]
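
Reviewer note: below is a minimal, self-contained sketch of the chunking
semantics that InferImplFlattenConcat computes above, written in plain NumPy
so it runs without this patch applied. The helper flatten_concat_reference is
hypothetical and not part of this change; it emits chunks in first-appearance
order of each dtype, whereas the real op orders chunks by internal TypeId
(std::map ordering), which happens to coincide for this example.

import numpy as np

def flatten_concat_reference(tensors):
    """Hypothetical NumPy mirror of FlattenConcat: flatten every input,
    then concatenate the flattened pieces that share a dtype."""
    chunks = {}
    for t in tensors:
        # Flatten each tensor and bucket it by its dtype.
        chunks.setdefault(t.dtype, []).append(t.reshape(-1))
    # One chunk tensor per dtype, in first-appearance order.
    return tuple(np.concatenate(parts) for parts in chunks.values())

t1 = np.array([1], np.float32)
t2 = np.array([2], np.float32)
t3 = np.array([3], np.float64)
t4 = np.array([4], np.float32)
t5 = np.array([5], np.float64)
chunks = flatten_concat_reference([t1, t2, t3, t4, t5])
print(chunks[0])  # [1. 2. 4.]  -- all float32 elements
print(chunks[1])  # [3. 5.]     -- all float64 elements

Because the C++ side accumulates sizes in a std::map keyed by TypeId, the
number of output chunks equals the number of distinct input dtypes, and the
chunk order is deterministic across runs.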