!32994 Add inner operator FlattenConcat

Merge pull request !32994 from hewei/flatten_weights
This commit is contained in:
i-robot 2022-04-15 07:39:59 +00:00 committed by Gitee
commit 12d28e25e2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 122 additions and 19 deletions

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -282,6 +282,8 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConcatOffset(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,13 +18,14 @@
#include <functional>
#include <iterator>
#include <numeric>
#include "abstract/abstract_value.h"
#include "abstract/ops/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/shape_utils.h"
#include "abstract/utils.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/shape_utils.h"
namespace mindspore {
namespace abstract {
@ -1126,6 +1127,54 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
return ret;
}
AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list) {
  // Infer for FlattenConcat: takes one tuple/list of tensors, flattens each tensor
  // and concatenates them into one 1-D chunk tensor per distinct data type.
  // Returns an AbstractTuple with one AbstractTensor per dtype group.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  auto seq = dyn_cast<abstract::AbstractSequence>(args_spec_list[0]);
  if (seq == nullptr) {
    MS_LOG(EXCEPTION) << "The input for '" << primitive->name() << "' should be tuple or list, but got "
                      << args_spec_list[0]->type_name();
  }
  // Group inputs by data type and accumulate their chunk sizes (element counts).
  // std::map is keyed by TypeId, so the output chunk order is deterministic.
  std::map<TypeId, size_t> chunks;
  for (auto &element : seq->elements()) {
    auto abs_tensor = dyn_cast<abstract::AbstractTensor>(element);
    if (abs_tensor == nullptr) {
      MS_LOG(EXCEPTION) << "The input element for '" << primitive->name() << "' should be Tensor, but got "
                        << element->type_name();
    }
    // Calculate data size (number of elements) by shape.
    auto base_shape = abs_tensor->BuildShape();
    MS_EXCEPTION_IF_NULL(base_shape);
    auto shape = base_shape->cast<ShapePtr>();
    if (shape == nullptr) {
      MS_LOG(EXCEPTION) << "The input tensors for '" << primitive->name() << "' should have shape, but got "
                        << base_shape->ToString();
    }
    // SizeOf() reports 0 for dynamic shapes (negative dimensions), which are rejected here.
    auto data_size = SizeOf(shape->shape());
    if (data_size == 0) {
      MS_LOG(EXCEPTION) << "The input tensors for '" << primitive->name() << "' should have static shape, but got "
                        << shape->ToString();
    }
    // Find data type from the AbstractTensor.
    const auto &element_abs = abs_tensor->element();
    MS_EXCEPTION_IF_NULL(element_abs);
    auto dtype = element_abs->BuildType();
    MS_EXCEPTION_IF_NULL(dtype);
    // Group them by data type.
    chunks[dtype->type_id()] += data_size;
  }
  // Make result AbstractTuple: one 1-D abstract tensor per dtype group.
  AbstractBasePtrList tuple_element;
  tuple_element.reserve(chunks.size());
  for (auto &chunk : chunks) {
    ShapeVector shape_vec{static_cast<int64_t>(chunk.second)};
    auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(chunk.first), shape_vec);
    (void)tuple_element.emplace_back(abs);
  }
  return std::make_shared<abstract::AbstractTuple>(std::move(tuple_element));
}
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();

View File

@ -182,6 +182,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMaskedSelect, R{InferImplMaskedSelect, nullptr, true}},
{prim::kPrimTensorCopySlices, R{InferImplTensorCopySlices, nullptr, true}},
{prim::kPrimNonZero, R{InferImplNonZero, nullptr, true}},
{prim::kPrimFlattenConcat, R{InferImplFlattenConcat, nullptr, true}},
// Structure
{prim::kPrimMakeTuple, R{InferImplMakeTuple, nullptr, true}},
{prim::kPrimMakeList, R{InferImplMakeList, nullptr, true}},

View File

@ -30,6 +30,7 @@
#include "base/complex_storage.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils_secure.h"
#include "utils/shape_utils.h"
namespace mindspore {
namespace tensor {
@ -50,18 +51,6 @@ static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
return data_type ? data_type->type_id() : defaultTypeId;
}
static size_t SizeOf(const ShapeVector &shape) {
int64_t data_size = 1;
for (auto dim : shape) {
if (dim < 0) {
// For dynamic shape which has negative dimensions, data size should be zero.
return 0;
}
data_size *= dim;
}
return static_cast<size_t>(data_size);
}
static std::string ShapeToString(const ShapeVector &shape) {
std::string str = "[";
const size_t count = shape.size();

View File

@ -122,6 +122,7 @@ constexpr auto kOnes = "Ones";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kIdentity = "Identity";
constexpr auto kConcat = "Concat";
constexpr auto kFlattenConcat = "FlattenConcat";
constexpr auto kRightShift = "RightShift";
constexpr auto kDiag = "Diag";
constexpr auto kDiagPart = "DiagPart";
@ -279,6 +280,7 @@ GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
GVAR_DEF(PrimitivePtr, kPrimFlattenConcat, std::make_shared<Primitive>(kFlattenConcat));
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
GVAR_DEF(PrimitivePtr, kPrimTranspose, std::make_shared<Primitive>(kTranspose));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,4 +19,18 @@
#include "mindapi/base/shape_vector.h"
namespace mindspore {
inline size_t SizeOf(const ShapeVector &shape) {
  // Element count for a static shape; 0 when the shape is dynamic.
  int64_t element_count = 1;
  for (const auto &dim : shape) {
    if (dim >= 0) {
      element_count *= dim;
    } else {
      // Negative dimension marks a dynamic shape: total size is unknown, report zero.
      return 0;
    }
  }
  return static_cast<size_t>(element_count);
}
} // namespace mindspore
#endif // MINDSPORE_SHAPE_UTILS_INFO_H_

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -1911,3 +1911,35 @@ class Format(PrimitiveWithInfer):
var_value.append(item["value"])
value = str_value.format(*var_value)
return {'dtype': mstype.string, 'shape': [], 'value': value}
class FlattenConcat(Primitive):
    """
    Flatten input tensors and concatenate them into several chunk tensors grouped by data types.

    Inputs:
        - **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.

    Outputs:
        tuple[Tensor], result chunk tensors, one chunk per distinct input data type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> t1 = Tensor(np.array([1]).astype(np.float32))
        >>> t2 = Tensor(np.array([2]).astype(np.float32))
        >>> t3 = Tensor(np.array([3]).astype(np.float64))
        >>> t4 = Tensor(np.array([4]).astype(np.float32))
        >>> t5 = Tensor(np.array([5]).astype(np.float64))
        >>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
        >>> print(chunks[0].asnumpy())
        [1. 2. 4.]
        >>> print(chunks[1].asnumpy())
        [3. 5.]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize FlattenConcat"""

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -359,6 +359,15 @@ class RangeNet(Cell):
return self.range_ops(x)
class NetForFlattenConcat(Cell):
    """Test network that applies inner.FlattenConcat to three input tensors."""

    def __init__(self):
        super(NetForFlattenConcat, self).__init__()
        self.concat_op = inner.FlattenConcat()

    def construct(self, x1, x2, x3):
        tensors = [x1, x2, x3]
        return self.concat_op(tensors)
test_case_array_ops = [
('CustNet1', {
'block': CustNet1(),
@ -408,6 +417,11 @@ test_case_array_ops = [
('RangeNet', {
'block': RangeNet(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}),
('FlattenConcat', {
'block': NetForFlattenConcat(),
'desc_inputs': [Tensor(np.array([1], np.float32)),
Tensor(np.array([2], np.float32)),
Tensor(np.array([3], np.float64))]}),
('TensorShapeNet', {'block': TensorShapeNet(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]})
]