forked from mindspore-Ecosystem/mindspore
add node output and shape pass
This commit is contained in:
parent
55c71a6034
commit
beeb36d075
|
@ -59,6 +59,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/fusion/tf_gelu_fusion.cc
|
||||
../optimizer/fusion/onnx_gelu_fusion.cc
|
||||
../optimizer/fusion/squeeze_fusion.cc
|
||||
../optimizer/fisson/fisson_util.cc
|
||||
../optimizer/fisson/iter_node_outputs.cc
|
||||
../optimizer/fisson/node_out_shapes.cc
|
||||
../optimizer/graph/conv1d_inout_adjust_pass.cc
|
||||
../optimizer/graph/weight_format_transform_pass.cc
|
||||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
#include "mindspore/core/base/core_ops.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
using lite::converter::FmkType;
|
||||
|
||||
namespace opt {
|
||||
std::unordered_map<std::string, std::vector<AnfNodePtr>> g_graph_nodes_output = {};
|
||||
std::unordered_map<std::string, std::vector<std::vector<ShapeVector>>> g_graph_nodes_out_shapes = {};
|
||||
|
||||
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(conv_cnode);
|
||||
|
||||
int32_t nodes_num = conv_outputs.size();
|
||||
if (nodes_num != split_info.out_num) {
|
||||
MS_LOG(ERROR) << "Conv outputs has wrong input size";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// the inputs of concate are from the outputs of conv
|
||||
std::vector<AnfNodePtr> concate_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
|
||||
for (int32_t i = 0; i < nodes_num; i++) {
|
||||
concate_inputs.push_back(conv_outputs[i]);
|
||||
}
|
||||
|
||||
auto concate_cnode = func_graph->NewCNode(concate_inputs);
|
||||
MS_EXCEPTION_IF_NULL(concate_cnode);
|
||||
|
||||
concate_cnode->set_fullname_with_scope(node_name + "_Concat");
|
||||
concate_cnode->set_scope(conv_cnode->scope());
|
||||
|
||||
return concate_cnode;
|
||||
}
|
||||
|
||||
int32_t GetCOutAxis(int32_t format) {
|
||||
switch (format) {
|
||||
case schema::Format_KHWC:
|
||||
return 0;
|
||||
case schema::Format_CHWK:
|
||||
return 3;
|
||||
case schema::Format_NCHW:
|
||||
return 0;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t GetCInAxis(int32_t format) {
|
||||
switch (format) {
|
||||
case schema::Format_KHWC:
|
||||
return 3;
|
||||
case schema::Format_CHWK:
|
||||
return 0;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t GetAxis(int32_t axis, int32_t format, const SplitInfo &split_info) {
|
||||
switch (split_info.primitive_type) {
|
||||
case mindspore::schema::PrimitiveType_Conv2DFusion:
|
||||
if (axis == CuttingStragedy::CUT_C_OUT) {
|
||||
return GetCOutAxis(format);
|
||||
} else if (axis == CuttingStragedy::CUT_C_IN) {
|
||||
return GetCInAxis(format);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only channel_in and channel_out need to transform.";
|
||||
}
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Now, do not support the type : " << split_info.primitive_type;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(conv_cnode);
|
||||
|
||||
int32_t nodes_num = conv_outputs.size();
|
||||
if (nodes_num != split_info.out_num) {
|
||||
MS_LOG(ERROR) << "Conv outputs has wrong input size";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// the inputs of addn are from the outputs of conv
|
||||
std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
|
||||
for (int32_t i = 0; i < nodes_num; i++) {
|
||||
addn_inputs.push_back(conv_outputs[i]);
|
||||
}
|
||||
|
||||
auto addn_cnode = func_graph->NewCNode(addn_inputs);
|
||||
MS_EXCEPTION_IF_NULL(addn_cnode);
|
||||
|
||||
addn_cnode->set_fullname_with_scope(node_name + "_AddN");
|
||||
addn_cnode->set_scope(conv_cnode->scope());
|
||||
|
||||
return addn_cnode;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "mindspore/lite/include/context.h"
|
||||
#include "mindspore/lite/include/lite_types.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::schema::PrimitiveType;
|
||||
namespace opt {
|
||||
extern std::unordered_map<std::string, std::vector<AnfNodePtr>> g_graph_nodes_output;
|
||||
extern std::unordered_map<std::string, std::vector<std::vector<ShapeVector>>> g_graph_nodes_out_shapes;
|
||||
|
||||
struct SplitInfo {
|
||||
int32_t axis;
|
||||
int32_t out_num;
|
||||
std::vector<int32_t> size_splits;
|
||||
std::vector<int32_t> extend_top;
|
||||
std::vector<int32_t> extend_bottom;
|
||||
std::vector<mindspore::lite::DeviceType> dev_types;
|
||||
int32_t in_num_conv;
|
||||
int32_t fmk_type;
|
||||
std::vector<int32_t> weight_channel;
|
||||
PrimitiveType primitive_type;
|
||||
};
|
||||
|
||||
typedef enum { CUT_N, CUT_H, CUT_W, CUT_C_IN, CUT_C_OUT, CUT_NONE } CuttingStragedy;
|
||||
|
||||
void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
|
||||
std::vector<AnfNodePtr> *outputs);
|
||||
|
||||
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name);
|
||||
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
std::vector<AnfNodePtr> *split_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name);
|
||||
|
||||
void GetCNodeShapeInfo(const FuncGraphPtr &func_graph, int32_t fmk_type);
|
||||
|
||||
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
|
||||
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
|
||||
const std::string &node_name);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fisson/iter_node_outputs.h"
|
||||
#include <vector>
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
AnfNodePtr IterNodeOutputs::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
|
||||
for (auto input_node : inputs) {
|
||||
if (!utils::isa<CNodePtr>(input_node)) {
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = input_node->cast<CNodePtr>();
|
||||
auto name = input_cnode->fullname_with_scope();
|
||||
if (g_graph_nodes_output.find(name) != g_graph_nodes_output.end()) {
|
||||
std::vector<AnfNodePtr>::iterator it;
|
||||
it = find(g_graph_nodes_output[name].begin(), g_graph_nodes_output[name].end(), node);
|
||||
if (it != g_graph_nodes_output[name].end()) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
g_graph_nodes_output[name].push_back(node);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class IterNodeOutputs : public opt::NodePass {
|
||||
public:
|
||||
IterNodeOutputs() : NodePass("iter_node_outputs") {}
|
||||
~IterNodeOutputs() override = default;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fisson/node_out_shapes.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "tools/optimizer/fisson/fisson_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
AnfNodePtr NodeOutShapes::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<ShapeVector> input_shapes;
|
||||
std::vector<ShapeVector> output_shapes;
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
// input
|
||||
for (auto input_node : cnode->inputs()) {
|
||||
if (utils::isa<CNodePtr>(input_node) || utils::isa<ParameterPtr>(input_node)) {
|
||||
auto in_shape = input_node->Shape();
|
||||
if (in_shape == nullptr) {
|
||||
MS_LOG(ERROR) << "The shape is null.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
if (utils::isa<abstract::ShapePtr>(in_shape)) {
|
||||
const auto &shape = in_shape->cast<abstract::ShapePtr>()->shape();
|
||||
input_shapes.push_back(shape);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "currently not support tuple";
|
||||
}
|
||||
}
|
||||
}
|
||||
// output
|
||||
auto out_shape = cnode->Shape();
|
||||
if (out_shape == nullptr) {
|
||||
MS_LOG(ERROR) << "The shape is null.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
if (utils::isa<abstract::TupleShapePtr>(out_shape)) {
|
||||
auto shape = out_shape->cast<abstract::TupleShapePtr>();
|
||||
for (size_t i = 0; i < shape->size(); ++i) {
|
||||
const auto &shape_ptr = (*shape)[i];
|
||||
if (!utils::isa<abstract::ShapePtr>(shape_ptr)) {
|
||||
MS_LOG(ERROR) << "shape_ptr is not ShapePtr.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
output_shapes.push_back(shape_ptr->cast<abstract::ShapePtr>()->shape());
|
||||
}
|
||||
} else if (utils::isa<abstract::ShapePtr>(out_shape)) {
|
||||
const auto &shape = out_shape->cast<abstract::ShapePtr>()->shape();
|
||||
output_shapes.push_back(shape);
|
||||
}
|
||||
|
||||
std::string name = cnode->fullname_with_scope();
|
||||
g_graph_nodes_out_shapes[name].push_back(input_shapes);
|
||||
g_graph_nodes_out_shapes[name].push_back(output_shapes);
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class NodeOutShapes : public opt::NodePass {
|
||||
public:
|
||||
NodeOutShapes() : NodePass("node_out_shapes") {}
|
||||
~NodeOutShapes() override = default;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
Loading…
Reference in New Issue