forked from mindspore-Ecosystem/mindspore
Add inner operator FlattenConcat
This commit is contained in:
parent
40fb67383f
commit
9a0a0ad5e8
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -282,6 +282,8 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConcatOffset(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,13 +18,14 @@
|
|||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/ops/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -1126,6 +1127,54 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFlattenConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
auto seq = dyn_cast<abstract::AbstractSequence>(args_spec_list[0]);
|
||||
if (seq == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The input for '" << primitive->name() << "' should be tuple or list, but got "
|
||||
<< args_spec_list[0]->type_name();
|
||||
}
|
||||
// Group inputs by data type and calculate their chunk sizes.
|
||||
std::map<TypeId, size_t> chunks;
|
||||
for (auto &element : seq->elements()) {
|
||||
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(element);
|
||||
if (abs_tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The input element for '" << primitive->name() << "' should be Tensor, but got "
|
||||
<< element->type_name();
|
||||
}
|
||||
// Calculate data size (number of elements) by shape.
|
||||
auto base_shape = abs_tensor->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(base_shape);
|
||||
auto shape = base_shape->cast<ShapePtr>();
|
||||
if (shape == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The input tensors for '" << primitive->name() << "' should have shape, but got "
|
||||
<< base_shape->ToString();
|
||||
}
|
||||
auto data_size = SizeOf(shape->shape());
|
||||
if (data_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "The input tensors for '" << primitive->name() << "'should have static shape, but got "
|
||||
<< shape->ToString();
|
||||
}
|
||||
// Find data type from the AbstractTensor.
|
||||
const auto &element_abs = abs_tensor->element();
|
||||
MS_EXCEPTION_IF_NULL(element_abs);
|
||||
auto dtype = element_abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
// Group them by data type.
|
||||
chunks[dtype->type_id()] += data_size;
|
||||
}
|
||||
// Make result AbstractTuple.
|
||||
AbstractBasePtrList tuple_element;
|
||||
tuple_element.reserve(chunks.size());
|
||||
for (auto &chunk : chunks) {
|
||||
ShapeVector shape_vec{static_cast<int64_t>(chunk.second)};
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(chunk.first), shape_vec);
|
||||
(void)tuple_element.emplace_back(abs);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTuple>(std::move(tuple_element));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name();
|
||||
|
|
|
@ -182,6 +182,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimMaskedSelect, R{InferImplMaskedSelect, nullptr, true}},
|
||||
{prim::kPrimTensorCopySlices, R{InferImplTensorCopySlices, nullptr, true}},
|
||||
{prim::kPrimNonZero, R{InferImplNonZero, nullptr, true}},
|
||||
{prim::kPrimFlattenConcat, R{InferImplFlattenConcat, nullptr, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, R{InferImplMakeTuple, nullptr, true}},
|
||||
{prim::kPrimMakeList, R{InferImplMakeList, nullptr, true}},
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "base/complex_storage.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils_secure.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace tensor {
|
||||
|
@ -50,18 +51,6 @@ static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
|
|||
return data_type ? data_type->type_id() : defaultTypeId;
|
||||
}
|
||||
|
||||
static size_t SizeOf(const ShapeVector &shape) {
|
||||
int64_t data_size = 1;
|
||||
for (auto dim : shape) {
|
||||
if (dim < 0) {
|
||||
// For dynamic shape which has negative dimensions, data size should be zero.
|
||||
return 0;
|
||||
}
|
||||
data_size *= dim;
|
||||
}
|
||||
return static_cast<size_t>(data_size);
|
||||
}
|
||||
|
||||
static std::string ShapeToString(const ShapeVector &shape) {
|
||||
std::string str = "[";
|
||||
const size_t count = shape.size();
|
||||
|
|
|
@ -122,6 +122,7 @@ constexpr auto kOnes = "Ones";
|
|||
constexpr auto kOnesLike = "OnesLike";
|
||||
constexpr auto kIdentity = "Identity";
|
||||
constexpr auto kConcat = "Concat";
|
||||
constexpr auto kFlattenConcat = "FlattenConcat";
|
||||
constexpr auto kRightShift = "RightShift";
|
||||
constexpr auto kDiag = "Diag";
|
||||
constexpr auto kDiagPart = "DiagPart";
|
||||
|
@ -279,6 +280,7 @@ GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
|
||||
GVAR_DEF(PrimitivePtr, kPrimFlattenConcat, std::make_shared<Primitive>(kFlattenConcat));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTranspose, std::make_shared<Primitive>(kTranspose));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,4 +19,18 @@
|
|||
|
||||
#include "mindapi/base/shape_vector.h"
|
||||
|
||||
namespace mindspore {
|
||||
inline size_t SizeOf(const ShapeVector &shape) {
|
||||
int64_t data_size = 1;
|
||||
for (auto dim : shape) {
|
||||
if (dim < 0) {
|
||||
// For dynamic shape which has negative dimensions, data size should be zero.
|
||||
return 0;
|
||||
}
|
||||
data_size *= dim;
|
||||
}
|
||||
return static_cast<size_t>(data_size);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_SHAPE_UTILS_INFO_H_
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1911,3 +1911,35 @@ class Format(PrimitiveWithInfer):
|
|||
var_value.append(item["value"])
|
||||
value = str_value.format(*var_value)
|
||||
return {'dtype': mstype.string, 'shape': [], 'value': value}
|
||||
|
||||
|
||||
class FlattenConcat(Primitive):
|
||||
"""
|
||||
Flatten input tensors and concatenate them into several chunk tensors grouped by data types.
|
||||
|
||||
Inputs:
|
||||
- **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], result chunk tensors.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops.operations import _inner_ops as inner
|
||||
>>> t1 = Tensor(np.array([1]).astype(np.float32))
|
||||
>>> t2 = Tensor(np.array([2]).astype(np.float32))
|
||||
>>> t3 = Tensor(np.array([3]).astype(np.float64))
|
||||
>>> t4 = Tensor(np.array([4]).astype(np.float32))
|
||||
>>> t5 = Tensor(np.array([5]).astype(np.float64))
|
||||
>>> chunks = inner.FlattenConcat()([t1, t2, t2, t3, t4, t5])
|
||||
>>> print(chunks[0].asnumpy())
|
||||
>>> print(chunks[1].asnumpy())
|
||||
[1. 2. 4.]
|
||||
[3. 5.]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize FlattenConcat"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -359,6 +359,15 @@ class RangeNet(Cell):
|
|||
return self.range_ops(x)
|
||||
|
||||
|
||||
class NetForFlattenConcat(Cell):
|
||||
def __init__(self):
|
||||
super(NetForFlattenConcat, self).__init__()
|
||||
self.flatten_concat = inner.FlattenConcat()
|
||||
|
||||
def construct(self, x1, x2, x3):
|
||||
return self.flatten_concat([x1, x2, x3])
|
||||
|
||||
|
||||
test_case_array_ops = [
|
||||
('CustNet1', {
|
||||
'block': CustNet1(),
|
||||
|
@ -408,6 +417,11 @@ test_case_array_ops = [
|
|||
('RangeNet', {
|
||||
'block': RangeNet(),
|
||||
'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}),
|
||||
('FlattenConcat', {
|
||||
'block': NetForFlattenConcat(),
|
||||
'desc_inputs': [Tensor(np.array([1], np.float32)),
|
||||
Tensor(np.array([2], np.float32)),
|
||||
Tensor(np.array([3], np.float64))]}),
|
||||
('TensorShapeNet', {'block': TensorShapeNet(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]})
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue