add node output and shape pass

This commit is contained in:
zhujingxuan 2021-04-26 11:34:29 +08:00
parent 55c71a6034
commit beeb36d075
7 changed files with 404 additions and 0 deletions
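In brief, the commit adds two converter node passes and a fisson utility: IterNodeOutputs records, for every CNode, the nodes that consume its outputs into g_graph_nodes_output; NodeOutShapes records each CNode's input and output shapes into g_graph_nodes_out_shapes; and fisson_util provides helpers (CreateOutputsOfConcat, CreateOutputsOfAddN, GetAxis) for stitching the outputs of a split convolution back together. A minimal sketch of how the two passes might be scheduled ahead of the split passes, assuming the GraphOptimizer/PassManager flow used elsewhere in the converter (the pass-manager name and call site below are illustrative, not part of this commit):

// Illustrative wiring only: run the bookkeeping passes so the global maps
// are filled before any fisson/split pass consumes them.
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto prepare_pm = std::make_shared<opt::PassManager>("fisson_prepare_pm", false);
prepare_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());  // fills g_graph_nodes_output
prepare_pm->AddPass(std::make_shared<opt::NodeOutShapes>());    // fills g_graph_nodes_out_shapes
optimizer->AddPassManager(prepare_pm);
func_graph = optimizer->Optimize(func_graph);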

mindspore/lite/tools/converter/CMakeLists.txt

@@ -59,6 +59,9 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_gelu_fusion.cc
../optimizer/fusion/onnx_gelu_fusion.cc
../optimizer/fusion/squeeze_fusion.cc
../optimizer/fisson/fisson_util.cc
../optimizer/fisson/iter_node_outputs.cc
../optimizer/fisson/node_out_shapes.cc
../optimizer/graph/conv1d_inout_adjust_pass.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc

mindspore/lite/tools/optimizer/fisson/fisson_util.cc

@@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include "tools/optimizer/fisson/fisson_util.h"
#include "mindspore/core/base/core_ops.h"
#include "src/common/utils.h"
#include "tools/common/node_util.h"
namespace mindspore {
using lite::converter::FmkType;
namespace opt {
// Maps a CNode's full name to the CNodes that consume its outputs (filled by IterNodeOutputs).
std::unordered_map<std::string, std::vector<AnfNodePtr>> g_graph_nodes_output = {};
// Maps a CNode's full name to {input shapes, output shapes} (filled by NodeOutShapes).
std::unordered_map<std::string, std::vector<std::vector<ShapeVector>>> g_graph_nodes_out_shapes = {};
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
const std::string &node_name) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(conv_cnode);
int32_t nodes_num = static_cast<int32_t>(conv_outputs.size());
if (nodes_num != split_info.out_num) {
MS_LOG(ERROR) << "The number of conv outputs does not match split_info.out_num.";
return nullptr;
}
// the inputs of the concat node come from the outputs of the conv splits
std::vector<AnfNodePtr> concate_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
for (int32_t i = 0; i < nodes_num; i++) {
concate_inputs.push_back(conv_outputs[i]);
}
auto concate_cnode = func_graph->NewCNode(concate_inputs);
MS_EXCEPTION_IF_NULL(concate_cnode);
concate_cnode->set_fullname_with_scope(node_name + "_Concat");
concate_cnode->set_scope(conv_cnode->scope());
return concate_cnode;
}
int32_t GetCOutAxis(int32_t format) {
switch (format) {
case schema::Format_KHWC:
return 0;
case schema::Format_CHWK:
return 3;
case schema::Format_NCHW:
return 0;
default:
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
return -1;
}
}
int32_t GetCInAxis(int32_t format) {
switch (format) {
case schema::Format_KHWC:
return 3;
case schema::Format_CHWK:
return 0;
default:
MS_LOG(ERROR) << "Do not support format: " << format << " now.";
return -1;
}
}
int32_t GetAxis(int32_t axis, int32_t format, const SplitInfo &split_info) {
switch (split_info.primitive_type) {
case mindspore::schema::PrimitiveType_Conv2DFusion:
if (axis == CuttingStragedy::CUT_C_OUT) {
return GetCOutAxis(format);
} else if (axis == CuttingStragedy::CUT_C_IN) {
return GetCInAxis(format);
} else {
MS_LOG(ERROR) << "Only channel_in and channel_out need to transform.";
}
break;
default:
MS_LOG(ERROR) << "Now, do not support the type : " << split_info.primitive_type;
}
return -1;
}
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
const std::string &node_name) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(conv_cnode);
int32_t nodes_num = static_cast<int32_t>(conv_outputs.size());
if (nodes_num != split_info.out_num) {
MS_LOG(ERROR) << "The number of conv outputs does not match split_info.out_num.";
return nullptr;
}
// the inputs of the AddN node come from the outputs of the conv splits
std::vector<AnfNodePtr> addn_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
for (int32_t i = 0; i < nodes_num; i++) {
addn_inputs.push_back(conv_outputs[i]);
}
auto addn_cnode = func_graph->NewCNode(addn_inputs);
MS_EXCEPTION_IF_NULL(addn_cnode);
addn_cnode->set_fullname_with_scope(node_name + "_AddN");
addn_cnode->set_scope(conv_cnode->scope());
return addn_cnode;
}
} // namespace opt
} // namespace mindspore
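For orientation, GetAxis maps a cutting strategy onto a concrete tensor axis for a given weight format; a small usage sketch with made-up SplitInfo values (only primitive_type is consulted by GetAxis):

// Hypothetical SplitInfo for illustration; GetAxis only reads primitive_type.
SplitInfo split_info;
split_info.primitive_type = schema::PrimitiveType_Conv2DFusion;
// For a KHWC conv weight the output channel is axis 0 and the input channel is axis 3.
int32_t axis_c_out = GetAxis(CUT_C_OUT, schema::Format_KHWC, split_info);  // returns 0
int32_t axis_c_in = GetAxis(CUT_C_IN, schema::Format_KHWC, split_info);    // returns 3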

mindspore/lite/tools/optimizer/fisson/fisson_util.h

@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_
#include <vector>
#include <string>
#include <unordered_map>
#include "schema/inner/model_generated.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
#include "mindspore/lite/include/context.h"
#include "mindspore/lite/include/lite_types.h"
namespace mindspore {
using mindspore::schema::PrimitiveType;
namespace opt {
extern std::unordered_map<std::string, std::vector<AnfNodePtr>> g_graph_nodes_output;
extern std::unordered_map<std::string, std::vector<std::vector<ShapeVector>>> g_graph_nodes_out_shapes;
struct SplitInfo {
int32_t axis;
int32_t out_num;
std::vector<int32_t> size_splits;
std::vector<int32_t> extend_top;
std::vector<int32_t> extend_bottom;
std::vector<mindspore::lite::DeviceType> dev_types;
int32_t in_num_conv;
int32_t fmk_type;
std::vector<int32_t> weight_channel;
PrimitiveType primitive_type;
};
typedef enum { CUT_N, CUT_H, CUT_W, CUT_C_IN, CUT_C_OUT, CUT_NONE } CuttingStragedy;
void GetMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num,
std::vector<AnfNodePtr> *outputs);
AnfNodePtr CreateOutputsOfConcat(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
const std::string &node_name);
void CreateOutputsOfSplitWithOverlap(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
std::vector<AnfNodePtr> *split_outputs, const SplitInfo &split_info,
const std::string &node_name);
void GetCNodeShapeInfo(const FuncGraphPtr &func_graph, int32_t fmk_type);
AnfNodePtr CreateOutputsOfAddN(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode,
const std::vector<AnfNodePtr> &conv_outputs, const SplitInfo &split_info,
const std::string &node_name);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_FISSON_UTIL_H_

mindspore/lite/tools/optimizer/fisson/iter_node_outputs.cc

@@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fisson/iter_node_outputs.h"
#include <vector>
#include "tools/optimizer/fisson/fisson_util.h"
namespace mindspore {
namespace opt {
AnfNodePtr IterNodeOutputs::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return nullptr;
}
if (!utils::isa<CNodePtr>(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
for (auto input_node : inputs) {
if (!utils::isa<CNodePtr>(input_node)) {
continue;
}
auto input_cnode = input_node->cast<CNodePtr>();
auto name = input_cnode->fullname_with_scope();
if (g_graph_nodes_output.find(name) != g_graph_nodes_output.end()) {
std::vector<AnfNodePtr>::iterator it;
it = find(g_graph_nodes_output[name].begin(), g_graph_nodes_output[name].end(), node);
if (it != g_graph_nodes_output[name].end()) {
continue;
}
}
g_graph_nodes_output[name].push_back(node);
}
return nullptr;
}
} // namespace opt
} // namespace mindspore
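Once this pass has visited every node, g_graph_nodes_output maps each producing CNode's full name to the list of CNodes that consume its outputs; a hedged lookup sketch (the producer name below is hypothetical):

// Hypothetical name; in practice it comes from cnode->fullname_with_scope().
const std::string producer = "Conv2DFusion-op7";
auto iter = g_graph_nodes_output.find(producer);
if (iter != g_graph_nodes_output.end()) {
  // Each entry is a consumer CNode recorded by IterNodeOutputs.
  MS_LOG(INFO) << producer << " feeds " << iter->second.size() << " node(s).";
}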

mindspore/lite/tools/optimizer/fisson/iter_node_outputs.h

@@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/anf.h"
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
namespace mindspore {
namespace opt {
class IterNodeOutputs : public opt::NodePass {
public:
IterNodeOutputs() : NodePass("iter_node_outputs") {}
~IterNodeOutputs() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_

mindspore/lite/tools/optimizer/fisson/node_out_shapes.cc

@@ -0,0 +1,81 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fisson/node_out_shapes.h"
#include <vector>
#include <string>
#include "tools/optimizer/fisson/fisson_util.h"
namespace mindspore {
namespace opt {
AnfNodePtr NodeOutShapes::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return nullptr;
}
std::vector<ShapeVector> input_shapes;
std::vector<ShapeVector> output_shapes;
if (!utils::isa<CNodePtr>(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
// input
for (auto input_node : cnode->inputs()) {
if (utils::isa<CNodePtr>(input_node) || utils::isa<ParameterPtr>(input_node)) {
auto in_shape = input_node->Shape();
if (in_shape == nullptr) {
MS_LOG(ERROR) << "The shape is null.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (utils::isa<abstract::ShapePtr>(in_shape)) {
const auto &shape = in_shape->cast<abstract::ShapePtr>()->shape();
input_shapes.push_back(shape);
} else {
MS_LOG(ERROR) << "currently not support tuple";
}
}
}
// output
auto out_shape = cnode->Shape();
if (out_shape == nullptr) {
MS_LOG(ERROR) << "The shape is null.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (utils::isa<abstract::TupleShapePtr>(out_shape)) {
auto shape = out_shape->cast<abstract::TupleShapePtr>();
for (size_t i = 0; i < shape->size(); ++i) {
const auto &shape_ptr = (*shape)[i];
if (!utils::isa<abstract::ShapePtr>(shape_ptr)) {
MS_LOG(ERROR) << "shape_ptr is not ShapePtr.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
output_shapes.push_back(shape_ptr->cast<abstract::ShapePtr>()->shape());
}
} else if (utils::isa<abstract::ShapePtr>(out_shape)) {
const auto &shape = out_shape->cast<abstract::ShapePtr>()->shape();
output_shapes.push_back(shape);
}
std::string name = cnode->fullname_with_scope();
g_graph_nodes_out_shapes[name].push_back(input_shapes);
g_graph_nodes_out_shapes[name].push_back(output_shapes);
return nullptr;
}
} // namespace opt
} // namespace mindspore
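For each visited CNode the pass stores two vectors under its full name: index 0 holds the input shapes and index 1 the output shapes. A small lookup sketch (the node name is hypothetical):

const std::string name = "Conv2DFusion-op7";  // hypothetical full name
const auto &shape_info = g_graph_nodes_out_shapes[name];
// shape_info[0]: shapes of the node's inputs, shape_info[1]: shapes of its outputs.
if (shape_info.size() == 2 && !shape_info[1].empty()) {
  const ShapeVector &first_output_shape = shape_info[1][0];
  MS_LOG(INFO) << name << " first output has rank " << first_output_shape.size();
}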

mindspore/lite/tools/optimizer/fisson/node_out_shapes.h

@@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/anf.h"
#include "mindspore/ccsrc/backend/optimizer/common/node_pass.h"
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
namespace mindspore {
namespace opt {
class NodeOutShapes : public opt::NodePass {
public:
NodeOutShapes() : NodePass("node_out_shapes") {}
~NodeOutShapes() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_