add tf while pass

use tensor name find order
This commit is contained in:
mengyuanli 2020-12-17 14:44:56 +08:00
parent 28052ad188
commit 89f96e347b
38 changed files with 1524 additions and 388 deletions

View File

@ -20,6 +20,7 @@
#include <map>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/ops/assert_op.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/space_to_batch_nd.h"
#include "src/ops/conv2d.h"
@ -614,6 +615,13 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Sqrt>(prim, inputs, quantType);
} else if (op_type == "Greater") {
return NewPrimitiveC<Greater>(prim, inputs, quantType);
} else if (op_type == "Switch") {
return NewPrimitiveC<Switch>(prim, inputs, quantType);
} else if (op_type == "Partial") {
return NewPrimitiveC<Partial>(prim, inputs, quantType);
} else if (op_type == "Merge") {
return NewPrimitiveC<Merge>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
@ -955,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Merge(primitive);
case schema::PrimitiveType_Partial:
return new (std::nothrow) Partial(primitive);
case schema::PrimitiveType_Assert:
return new (std::nothrow) AssertOP(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);

View File

@ -156,7 +156,8 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Transpose);
MS_ASSERT(desc.type == schema::PrimitiveType_Transpose || desc.type == schema::PrimitiveType_Nchw2Nhwc ||
desc.type == schema::PrimitiveType_Nhwc2Nchw);
if (opParameter == nullptr) {
MS_LOG(ERROR) << "desc type is not Transpose";
return nullptr;

View File

@ -200,6 +200,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/while_pass.cc
)
endif()
### train

View File

@ -7,7 +7,7 @@ rcnn-ilsvrc13-9.onnx
mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx
densenet-9.onnx
#densenet-9.onnx
ml_table_detection_fp32.onnx
ml_table_segment.onnx
googlenet-9.onnx
@ -27,7 +27,7 @@ psenet_lite_mbv2.onnx;1,32,32,3
super-resolution-10.onnx;1,224,224,1
tinyyolov2-8.onnx;1,416,416,3
ml_2012_ocr_cn.onnx
ml_2012_ocr_cn_noLSTM.onnx
#ml_2012_ocr_cn_noLSTM.onnx
candy-9.onnx
mosaic-9.onnx
pointilism-9.onnx

View File

@ -7,7 +7,7 @@ emotion-ferplus-8.onnx 1
mobilenetv2-7.onnx 8
shufflenet-v2-10.onnx 5
squeezenet1.1-7.onnx 1
densenet-9.onnx 6
#densenet-9.onnx 6
ml_table_detection_fp32.onnx 2
ml_table_segment.onnx 2
googlenet-9.onnx 3
@ -27,7 +27,7 @@ mnist-8.onnx 10
#super-resolution-10.onnx 1
#tinyyolov2-8.onnx 0.3
ml_2012_ocr_cn.onnx 200
ml_2012_ocr_cn_noLSTM.onnx 1
#ml_2012_ocr_cn_noLSTM.onnx 1
candy-9.onnx 5
mosaic-9.onnx 4
pointilism-9.onnx 3

View File

@ -28,6 +28,8 @@
#include "src/tensor.h"
#include "src/param_value_lite.h"
#include "src/common/utils.h"
#include "src/ops/partial.h"
#include "tools/common/graph_util.h"
namespace mindspore::lite {
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
@ -73,7 +75,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
hasDepend = true;
bool maskOut = (dependNode->inputs().size() == 3) ? true : false;
bool maskOut = (dependNode->inputs().size() == 3);
for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
AnfNodePtr dependInputNode = dependNode->input(j);
if (dependInputNode->isa<CNode>()) {
@ -172,22 +174,50 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
return RET_OK;
}
void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
for (auto node : graph_input_nodes_) {
std::vector<schema::CNodeT *> AnfExporter::GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index) {
std::vector<schema::CNodeT *> subgraph_nodes{};
subgraph_nodes.resize(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.size());
std::transform(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.begin(),
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.end(), subgraph_nodes.begin(),
[&meta_graphT](const uint32_t idx) { return meta_graphT->nodes.at(idx).get(); });
return subgraph_nodes;
}
int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index) {
auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
auto subgraph_nodes = GetSubgraphNodes(meta_graphT, subgraph_index);
std::vector<schema::CNodeT *> subgraph_input_nodes{};
for (auto &node : subgraph_nodes) {
if (IsContain(graph_input_nodes_, node)) {
subgraph_input_nodes.push_back(node);
}
}
std::vector<schema::TensorT *> subgraph_inputs{};
for (auto &node : subgraph_input_nodes) {
for (auto input : node->inputIndex) {
auto tensor = meta_graphT->allTensors[input].get();
if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) {
tensor->nodeType = schema::NodeType_ValueNode;
tensor->format = schema::Format_NHWC;
if (!IsContain(meta_graphT->inputIndex, input)) {
meta_graphT->inputIndex.emplace_back(input);
if (!IsContain(subgraph->inputIndices, input)) {
if (subgraph_index == kMainGraphIndex) {
meta_graphT->inputIndex.push_back(input);
}
subgraph->inputIndices.push_back(input);
subgraph_inputs.push_back(tensor);
}
}
}
}
return RET_OK;
}
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT,
schema::CNodeT *return_node) {
MS_ASSERT(nullptr != meta_graphT);
MS_ASSERT(nullptr != return_node);
@ -202,28 +232,62 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt
MS_LOG(ERROR) << "obtain outputs failed";
return ret;
}
} else if (input_node->isa<Parameter>()) {
MS_LOG(INFO) << "the node " << input_node->fullname_with_scope().c_str() << "is parameter node";
continue;
} else {
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
return RET_ERROR;
}
}
for (unsigned int &i : return_node->inputIndex) {
meta_graphT->outputIndex.push_back(i);
if (subgraph_index == kMainGraphIndex) {
meta_graphT->outputIndex.push_back(i);
}
meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
}
return RET_OK;
}
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
auto cnodes = func_graph->GetOrderedCnodes();
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode) {
int ret = RET_OK;
meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>());
auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index);
auto subgraph_name = func_graph->get_attr("graph_name");
MS_ASSERT(subgraph_name != nullptr);
sub_graphT->name = GetValue<std::string>(subgraph_name);
auto cnodes = func_graph->GetOrderedCnodes();
for (const auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
ret = RET_MEMORY_FAILED;
break;
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
if (fg != nullptr) {
auto partial_cnode = CreatePartialCnode(fg, cnode);
primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(partial_cnode->input(0));
auto primT = primitive_c->primitiveT();
auto pos = fg_subgraph_map.find(fg);
if (pos != fg_subgraph_map.end()) {
primT->value.AsPartial()->subGraphIndex = fg_subgraph_map.at(fg);
} else {
size_t next_subgraph_index = fg_subgraph_map.size() + 1;
fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index});
primT->value.AsPartial()->subGraphIndex = next_subgraph_index;
ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ExportSubgraph failed";
break;
}
}
} else {
MS_LOG(ERROR) << "primitive_c is nullptr";
ret = RET_MEMORY_FAILED;
break;
}
}
#ifdef SUPPORT_TRAIN
RemoveIfMakeTuple(cnode);
RemoveIfDepend(cnode);
@ -249,13 +313,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
}
if (primT->value.type == schema::PrimitiveType_Return) {
node->name = "return_node";
ret = SetGraphoutputIndex(cnode, meta_graphT, node.get());
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpOutputN failed";
break;
}
continue;
}
node->nodeType = schema::NodeType_CNode;
node->name = cnode->fullname_with_scope();
if (copy_primitive) {
@ -281,21 +346,45 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
if (!keep_graph) {
primitive_c->ClearPrimitiveT();
}
meta_graphT->nodes.emplace_back(std::move(node));
meta_graphT->nodes.push_back(std::move(node));
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++);
}
if (ret != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}
ret = SetGraphInputIndex(meta_graphT, subgraph_index);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetGraphInputIndex failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}
ret = SetSubgraphTensorIndices(meta_graphT.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}
return RET_OK;
}
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
static int subgraph_index = 0;
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
if (ret != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
// set graph input tensors
SetGraphInputIndex(meta_graphT);
return meta_graphT.release();
}
int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
std::string input_name = input_anode->fullname_with_scope();
auto input_cnode = utils::cast<CNodePtr>(input_anode);
if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
#ifndef SUPPORT_TRAIN
if (node_id_map_.find(input_name) != node_id_map_.end()) {
@ -343,11 +432,11 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0
iter = node_id_map_.find(input_index_key);
if (iter == node_id_map_.end()) {
MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key;
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
return RET_ERROR;
}
#else
MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key;
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
return RET_ERROR;
#endif
}
@ -367,6 +456,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
}
auto paramTensor = std::make_unique<schema::TensorT>();
paramTensor->format = schema::Format_NHWC;
paramTensor->name = paramNode->name();
auto abstractBase = paramNode->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
@ -518,6 +608,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<FuncGraph>()) {
MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
return RET_OK;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
@ -644,6 +737,20 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
}
bool AnfExporter::HasPrimitiveCNode(const AnfNodePtr &node) {
MS_ASSERT(node != nullptr);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return false;
}
auto prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (prim == nullptr) {
return false;
}
return true;
}
bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) {
MS_ASSERT(node != nullptr);
auto cnode = node->cast<CNodePtr>();
@ -658,6 +765,47 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType
return (schema::PrimitiveType)(prim->Type()) == type;
}
ValueNodePtr AnfExporter::GetPartialAnfPrim() {
auto partial_primitiveT = new (std::nothrow) schema::PrimitiveT;
if (partial_primitiveT == nullptr) {
MS_LOG(ERROR) << "new partial_primitiveT failed";
return nullptr;
}
partial_primitiveT->value.type = schema::PrimitiveType_Partial;
partial_primitiveT->value.value = new (std::nothrow) schema::PartialT;
if (partial_primitiveT->value.value == nullptr) {
MS_LOG(ERROR) << "new PartialT failed";
return nullptr;
}
auto partial_prim = std::make_shared<lite::Partial>(partial_primitiveT);
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
return partial_anf_prim;
}
CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node) {
if (utils::isa<CNodePtr>(node)) {
auto cnode = utils::cast<CNodePtr>(node);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c != nullptr) {
return cnode;
}
auto partial_anf_prim_vnode = GetPartialAnfPrim();
auto cnode_input = cnode->inputs();
cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode);
cnode->set_inputs(cnode_input);
return cnode;
} else if (utils::isa<ValueNodePtr>(node)) {
auto partial_anf_prim_vnode = GetPartialAnfPrim();
std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node};
auto cnode = fg->NewCNode(inputs);
return cnode;
} else {
MS_LOG(ERROR) << "failed to create partial cnode.";
return nullptr;
}
}
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
AnfExporter anf_exporter;
return anf_exporter.Export(func_graph, keep_graph, copy_primitive);

View File

@ -27,6 +27,10 @@
#include "tools/converter/converter_context.h"
namespace mindspore::lite {
constexpr const int kPartialMinSize = 3;
constexpr const int kMainGraphIndex = 0;
class AnfExporter {
public:
AnfExporter() = default;
@ -45,17 +49,28 @@ class AnfExporter {
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node);
int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT, schema::CNodeT *return_node);
static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
static bool HasPrimitiveCNode(const AnfNodePtr &node);
static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);
int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode = nullptr);
ValueNodePtr GetPartialAnfPrim();
CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode);
std::vector<schema::CNodeT *> GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index);
private:
std::map<std::string, int> node_id_map_;
std::vector<schema::CNodeT *> graph_input_nodes_;
std::map<FuncGraphPtr, int> fg_subgraph_map;
uint32_t node_idx = 0;
};
// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
// but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify

View File

@ -272,18 +272,40 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
continue;
}
}
// update graph input indexes
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indexes
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
// update subgraph output indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}
// update nodes indexes
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
// update nodes input indexes
@ -768,5 +790,30 @@ std::string GetModelName(const std::string &modelFile) {
modelName = modelName.substr(0, modelName.find_last_of('.'));
return modelName;
}
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
for (auto &subgraph : meta_graphT->subGraph) {
std::vector<uint32_t> subgraph_indices{};
for (auto &node_idx : subgraph->nodeIndices) {
auto &node = meta_graphT->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
if (IsContain(subgraph_indices, input_idx)) {
continue;
} else {
subgraph_indices.push_back(input_idx);
}
}
for (auto &output_idx : node->outputIndex) {
if (IsContain(subgraph_indices, output_idx)) {
continue;
} else {
subgraph_indices.push_back(output_idx);
}
}
}
subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -92,6 +92,8 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);
std::string GetModelName(const std::string &modelFile);
} // namespace lite
} // namespace mindspore

View File

@ -59,6 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/onnx_inputs_adjust_pass.cc
../optimizer/graph/while_pass.cc
)
add_subdirectory(../anf_importer anf_importer)

View File

@ -42,6 +42,7 @@
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/while_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@ -52,18 +53,21 @@ AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
if (config == nullptr) {
MS_LOG(ERROR) << "config shoud be specified";
MS_LOG(ERROR) << "config should be specified";
return nullptr;
}
if (old_graph->has_flag("HasTransformed")) {
old_graph->set_flag("HasTransformed", false);
return old_graph;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
// mindir pre adjustment
if (config->fmk == converter::FmkType_MS) {
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
@ -85,7 +89,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
}
// for now - trainning is not supporting fuse operations
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) {
auto while_pass = std::make_shared<opt::WhilePass>();
graph_pm->AddPass(while_pass);
}
// for now - training is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
@ -191,7 +200,46 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
}
return new_graph;
}
STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs,
std::vector<ValueNodePtr> *vnodes) {
auto nodes = TopoSort(main_graph->get_return());
for (auto &node : nodes) {
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg) {
vnodes->push_back(utils::cast<ValueNodePtr>(node));
subgraphs->push_back(fg);
}
}
return RET_OK;
}
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
// transform main_graph
auto new_main_graph = TransformSingleFuncGraph(main_graph, config);
if (new_main_graph == nullptr) {
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
// transform sub_graph
FuncGraphPtrList subgraphs{};
std::vector<ValueNodePtr> vnodes{};
int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
for (size_t i = 0; i < subgraphs.size(); i++) {
auto new_graph = Transform(subgraphs.at(i), config);
new_graph->set_flag("HasTransformed", true);
vnodes.at(i)->set_value(new_graph);
}
return new_main_graph;
}
} // namespace mindspore::lite

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H
#include <memory>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
@ -34,6 +35,9 @@ class AnfTransform {
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
private:
STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs,
std::vector<ValueNodePtr> *vnodes);
FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
} // namespace lite

View File

@ -67,6 +67,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
int status = modelImporter->Import(flag);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
graph = modelImporter->GetResult();
graph->set_attr("graph_name", MakeValue("main_graph"));
} else {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
@ -90,6 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
}
// transform
transform->SetGraphDef(meta_graph);
auto status = transform->Transform(*flag);

View File

@ -16,6 +16,7 @@
#include "tools/converter/graphdef_transform.h"
#include <string>
#include <algorithm>
#include "schema/model_generated.h"
#include "src/common/log_adapter.h"
#include "tools/converter/converter_flags.h"
@ -37,9 +38,21 @@
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
using std::string;
namespace mindspore::lite {
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
std::vector<schema::CNodeT *> old_nodes{};
old_nodes.resize(graphDefT->nodes.size());
std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(),
[](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
return old_nodes;
}
GraphDefTransform::GraphDefTransform() = default;
GraphDefTransform::~GraphDefTransform() = default;
@ -48,141 +61,232 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _
int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
// postconvert pass
{
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
if (ctx.fmk != converter::FmkType_TF) {
{
auto old_nodes = GetGraphNodes();
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
}
// format transform
{
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
// postconvert pass
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
}
}
// format transform
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
}
// do quantization
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensorQuantOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}
// insert quantNode and deQuantNode
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
}
}
// switch pass
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer switchOptimizer;
switchOptimizer.AddPass(new (std::nothrow) SwitchPass());
switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = switchOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run switch graphPasses Failed";
return status;
}
}
// do quantization
// subgraph tensor pass
{
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
status = tensorQuantOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}
// insert quantNode and deQuantNode
{
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
status = quantNodeOptimizer.Run(graphDefT);
Optimizer subgraphTensorOptimizer;
subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = subgraphTensorOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
return status;
}
}
// tensor name
{
// init old node indecies
auto old_nodes = GetGraphNodes();
Optimizer nameOptimizer;
nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
status = nameOptimizer.Run(graphDefT);
@ -192,16 +296,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
return RET_OK;
}
} // namespace mindspore::lite
} // namespace mindspore::lite

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H
#include <memory>
#include <vector>
#include "tools/converter/optimizer.h"
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
@ -39,6 +40,7 @@ class GraphDefTransform {
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
protected:
std::vector<schema::CNodeT *> GetGraphNodes();
schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr;
};

View File

@ -15,6 +15,9 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
)
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})

View File

@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
for (auto &subgraph : graph->subGraph) {
for (auto &idx : subgraph->nodeIndices) {
if (idx > node_idx) {
idx--;
}
}
}
}
STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
std::vector<schema::CNodeT *> new_nodes{};
std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
[](std::unique_ptr<CNodeT> &node) { return node.get(); });
for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
if (!IsContain(new_nodes, *it)) {
size_t node_idx = it - old_nodes_.begin();
for (auto &subgraph : graph->subGraph) {
auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
if (node_idx_pos != subgraph->nodeIndices.end()) {
subgraph->nodeIndices.erase(node_idx_pos);
UpdateSubgraphNodeIndices(node_idx, graph);
break;
}
}
it = old_nodes_.erase(it);
} else {
it++;
}
}
for (uint32_t i = 0; i < new_nodes.size(); i++) {
if (!IsContain(old_nodes_, new_nodes[i])) {
for (auto &subgraph : graph->subGraph) {
if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) {
subgraph->nodeIndices.push_back(old_nodes_.size());
old_nodes_.push_back(new_nodes[i]);
}
}
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
#include <vector>
#include <utility>
#include "tools/converter/optimizer.h"
namespace mindspore {
namespace lite {
class SubgraphNodePass : public GraphPass {
public:
explicit SubgraphNodePass(std::vector<schema::CNodeT *> old_nodes) : old_nodes_(std::move(old_nodes)) {}
~SubgraphNodePass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
std::vector<schema::CNodeT *> old_nodes_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H

View File

@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
for (const auto &node : graph->nodes) {
if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) {
return true;
}
if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) {
return true;
}
}
return false;
}
STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
for (const auto &subgraph : graph->subGraph) {
UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx);
UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx);
}
for (const auto &node : graph->nodes) {
UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx);
UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx);
}
UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx);
UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx);
return RET_OK;
}
STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) {
uint32_t idx = it - graph->allTensors.begin();
if (IsUsing(graph, idx)) {
it++;
} else {
it = graph->allTensors.erase(it);
UpdateTensorIdx(graph, idx);
}
}
return RET_OK;
}
STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
MS_ASSERT(graph->subGraph.size() > 0);
graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end());
return RET_OK;
}
STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
int ret = RemoveUselessTensors(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret;
return ret;
}
ret = SetSubgraphTensorIndices(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
return ret;
}
ret = SyncMainGraphInputAndOutput(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
return ret;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
#include <vector>
#include <utility>
#include "tools/converter/optimizer.h"
namespace mindspore {
namespace lite {
class SubgraphTensorPass : public GraphPass {
public:
SubgraphTensorPass() = default;
~SubgraphTensorPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
STATUS RemoveUselessTensors(schema::MetaGraphT *graph);
bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph);
template <typename T>
void UpdateVec(std::vector<T> *vec, T element) {
for (auto iter = vec->begin(); iter != vec->end(); iter++) {
if (*iter > element) {
(*iter)--;
}
}
}
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H

View File

@ -16,6 +16,7 @@
#include <vector>
#include <map>
#include <algorithm>
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
@ -96,38 +97,6 @@ std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_p
return out_tensor;
}
STATUS SingleSwitchPass::MoveMaxIterationToCond() {
auto &body_subgraph_input = graph_->subGraph.at(body_subgraph_index_)->inputIndices;
for (auto it = body_subgraph_input.begin(); it != body_subgraph_input.end();) {
if (!body_to_cond_partial_node_->inputIndex.empty() && IsContain(body_to_cond_partial_node_->inputIndex, *it)) {
int32_t max_iteration_idx = it - body_subgraph_input.begin();
// get maxiteration tensor
auto &max_iteration_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex.at(max_iteration_idx));
auto all_tensor_idx = std::find(graph_->allTensors.begin(), graph_->allTensors.end(), max_iteration_tensor) -
graph_->allTensors.begin();
// remove maxiteration from body_to_cond partial node
body_to_cond_partial_node_->inputIndex.erase(body_to_cond_partial_node_->inputIndex.begin() + max_iteration_idx);
// concat body subgraph tensor to max iteration in all tensor
auto body_max_iteration_tensor_idx = body_subgraph_input.at(max_iteration_idx);
for (auto &node : cond_graph_nodes_) {
std::replace_if(
node->inputIndex.begin(), node->inputIndex.end(),
[&body_max_iteration_tensor_idx](uint32_t idx) { return idx == body_max_iteration_tensor_idx; },
all_tensor_idx);
}
// remove maxiteration from body partial input and body func input
body_partial_node_->inputIndex.erase(body_partial_node_->inputIndex.begin() + max_iteration_idx);
it = body_subgraph_input.erase(it);
} else {
it++;
}
}
return RET_OK;
}
STATUS SingleSwitchPass::InsertMerge() {
int ret = RET_OK;
auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
@ -154,9 +123,9 @@ STATUS SingleSwitchPass::InsertMerge() {
}
// double merge inputs to contain the outputs of body node
for (auto &out_index : origin_switch_output_tensor_indices_) {
auto &switch_out_tensor = graph_->allTensors.at(out_index);
auto tensor = NewTensor(switch_out_tensor);
for (auto &index : cond_partial_node_->inputIndex) {
auto &in_tensor = graph_->allTensors.at(index);
auto tensor = NewTensor(in_tensor);
graph_->allTensors.push_back(std::move(tensor));
merge_node->inputIndex.push_back(graph_->allTensors.size() - 1);
}
@ -266,10 +235,6 @@ STATUS SingleSwitchPass::Init() {
return RET_NULL_PTR;
}
if (switch_node_->inputIndex.size() == kSwitchMinInputSize) {
return RET_OK;
}
if (switch_node_->inputIndex.size() < kSwitchMinInputSize) {
MS_LOG(ERROR) << "switch node: " << switch_node_->name
<< " 's input size is not right, size: " << switch_node_->inputIndex.size();
@ -297,10 +262,6 @@ STATUS SingleSwitchPass::Init() {
}
}
if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial ||
body_partial_node_->primitive->value.type != PrimitiveType_Partial) {
return RET_OK;
}
// get cond_graph_nodes_
cond_subgraph_index_ = cond_partial_node_->primitive->value.AsPartial()->subGraphIndex;
auto cond_node_indices = graph_->subGraph.at(cond_subgraph_index_)->nodeIndices;
@ -330,17 +291,36 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
return RET_INPUT_PARAM_INVALID;
}
auto &partial_inputs = partial_node->inputIndex;
auto &subgraph_inputs = graph_->subGraph.at(subgraph_index)->inputIndices;
auto &subgraph = graph_->subGraph.at(subgraph_index);
auto &subgraph_inputs = subgraph->inputIndices;
std::map<int, int> subgraph_input_map;
std::vector<int> new_subgraph_inputs{};
std::vector<std::pair<int, int>> tmp_inputs_order{};
for (unsigned int &subgraph_input : subgraph_inputs) {
auto &tensor = graph_->allTensors.at(subgraph_input);
// get parameter input index k. subgraph name + “_input_" + "k"
char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 7];
int partial_idx = k - '0';
if (tensor->name.size() < subgraph->name.size() + 8) {
MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right.";
return RET_ERROR;
}
int partial_idx = -1;
if (tensor->name.find("_input_") != std::string::npos) {
// get parameter input index k. subgraph name + “_input_" + "k"
auto pos = subgraph->name.size() + sizeof("_input_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2);
partial_idx = std::stoi(idx_str);
}
if (tensor->name.find("_output_") != std::string::npos) {
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2);
partial_idx = std::stoi(idx_str);
}
subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]});
new_subgraph_inputs.push_back(partial_inputs[partial_idx]);
tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]);
}
for (auto &subgraph_node : subgraph_nodes) {
@ -350,6 +330,13 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
}
}
}
std::sort(tmp_inputs_order.begin(), tmp_inputs_order.end(),
[](std::pair<int, int> a, std::pair<int, int> b) { return a.first < b.first; });
std::vector<int> new_subgraph_inputs{};
std::transform(tmp_inputs_order.begin(), tmp_inputs_order.end(), std::back_inserter(new_subgraph_inputs),
[](std::pair<int, int> iter) { return iter.second; });
subgraph_inputs.assign(new_subgraph_inputs.begin(), new_subgraph_inputs.end());
return RET_OK;
@ -362,17 +349,28 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
return RET_INPUT_PARAM_INVALID;
}
auto &partial_outputs = partial_node->outputIndex;
auto &subgraph_outputs = graph_->subGraph.at(subgraph_index)->outputIndices;
auto &subgraph = graph_->subGraph.at(subgraph_index);
auto &subgraph_outputs = subgraph->outputIndices;
std::map<int, int> subgraph_output_map;
std::vector<int> new_subgraph_outputs{};
std::vector<std::pair<int, int>> tmp_outputs_order{};
for (unsigned int &subgraph_output : subgraph_outputs) {
auto &tensor = graph_->allTensors.at(subgraph_output);
// get parameter input index k. subgraph name + “_output_" + "k"
char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 8];
int partial_idx = k - '0';
subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]});
new_subgraph_outputs.push_back(partial_outputs[partial_idx]);
for (auto &node : subgraph_nodes) {
if (IsContain(node->outputIndex, subgraph_output)) {
int partial_idx = -1;
if (node->name == "LogicalAnd") {
partial_idx = 0;
} else {
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = node->name.find('_', pos);
auto idx_str = node->name.substr(pos - 1, pos2);
partial_idx = std::stoi(idx_str);
}
subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]});
tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]);
}
}
}
for (auto &subgraph_node : subgraph_nodes) {
@ -382,6 +380,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
}
}
}
std::vector<int> new_subgraph_outputs{};
std::transform(tmp_outputs_order.begin(), tmp_outputs_order.end(), std::back_inserter(new_subgraph_outputs),
[](std::pair<int, int> iter) { return iter.second; });
subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end());
return RET_OK;
@ -416,102 +418,6 @@ STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() {
return ret;
}
STATUS SingleSwitchPass::ConvertSwitchToSelect() {
MS_ASSERT(switch_node_->inputIndex.size() >= 3);
MS_ASSERT(switch_node_->inputIndex.size() % 2 != 0);
MS_ASSERT(switch_node_->outputIndex.size() * 2 + 1 == switch_node_->inputIndex.size());
auto bool_index = switch_node_->inputIndex.front();
// insert switch node1
auto switch_node1 = std::make_unique<CNodeT>();
switch_node1->name = switch_node_->name + "-Switch-1";
switch_node1->primitive = std::make_unique<PrimitiveT>();
switch_node1->primitive->value.type = PrimitiveType_Switch;
switch_node1->primitive->value.value = new (std::nothrow) SwitchT();
switch_node1->inputIndex = {bool_index};
std::vector<int> part_one_input_index(
switch_node_->inputIndex.begin() + 1,
switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2);
switch_node1->inputIndex.insert(switch_node1->inputIndex.end(), part_one_input_index.begin(),
part_one_input_index.end());
std::vector<std::unique_ptr<TensorT>> switch_output_tensors1(part_one_input_index.size() * 2);
std::vector<int> switch_output_indexes1(part_one_input_index.size() * 2);
int i = 0;
for (const auto &input_index : part_one_input_index) {
auto &switch_in_tensor = graph_->allTensors.at(input_index);
auto tensor1 = NewTensor(switch_in_tensor);
auto tensor2 = NewTensor(switch_in_tensor);
switch_output_tensors1[i] = std::move(tensor1);
switch_output_tensors1[part_one_input_index.size() + i] = std::move(tensor2);
switch_output_indexes1[i] = graph_->allTensors.size() - 1 + i;
switch_output_indexes1[part_one_input_index.size() + i] =
graph_->allTensors.size() - 1 + i + part_one_input_index.size();
i++;
}
for (auto &tensor : switch_output_tensors1) {
graph_->allTensors.emplace_back(std::move(tensor));
}
switch_node1->outputIndex.insert(switch_node1->outputIndex.begin(), switch_output_indexes1.begin(),
switch_output_indexes1.end());
// insert switch node2
auto switch_node2 = std::make_unique<CNodeT>();
switch_node2->name = switch_node_->name + "-Switch-1";
switch_node2->primitive = std::make_unique<PrimitiveT>();
switch_node2->primitive->value.type = PrimitiveType_Switch;
switch_node2->primitive->value.value = new (std::nothrow) SwitchT();
switch_node2->inputIndex = {bool_index};
std::vector<int> part_two_input_index(
switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2, switch_node_->inputIndex.end());
switch_node2->inputIndex.insert(switch_node2->inputIndex.end(), part_two_input_index.begin(),
part_two_input_index.end());
std::vector<std::unique_ptr<TensorT>> switch_output_tensors2(part_two_input_index.size() * 2);
std::vector<int> switch_output_indexes2(part_two_input_index.size() * 2);
i = 0;
for (const auto &input_index : part_two_input_index) {
auto &switch_in_tensor = graph_->allTensors.at(input_index);
auto tensor1 = NewTensor(switch_in_tensor);
auto tensor2 = NewTensor(switch_in_tensor);
switch_output_tensors2[i] = std::move(tensor1);
switch_output_tensors2[part_two_input_index.size() + i] = std::move(tensor2);
switch_output_indexes2[i] = graph_->allTensors.size() - 1 + i;
switch_output_indexes2[part_two_input_index.size() + i] =
graph_->allTensors.size() - 1 + i + part_two_input_index.size();
i++;
}
for (auto &tensor : switch_output_tensors2) {
graph_->allTensors.emplace_back(std::move(tensor));
}
switch_node2->outputIndex.insert(switch_node2->outputIndex.begin(), switch_output_indexes2.begin(),
switch_output_indexes2.end());
// insert merge
auto merge_node = std::make_unique<CNodeT>();
merge_node->name = switch_node_->name + "-Merge";
merge_node->primitive = std::make_unique<PrimitiveT>();
merge_node->primitive->value.type = PrimitiveType_Merge;
merge_node->primitive->value.value = new (std::nothrow) MergeT();
std::vector<int> merge_input_indexes(switch_node_->outputIndex.size() * 2);
for (i = 0; i < switch_node_->outputIndex.size(); i++) {
merge_input_indexes[i] = switch_output_indexes1[i];
merge_input_indexes[i + switch_node_->outputIndex.size()] =
switch_output_indexes2[i + switch_node_->outputIndex.size()];
merge_node->outputIndex.emplace_back(switch_node_->outputIndex.at(i));
}
merge_node->inputIndex.insert(merge_node->inputIndex.end(), merge_input_indexes.begin(), merge_input_indexes.end());
graph_->nodes.emplace_back(std::move(switch_node1));
graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);
graph_->nodes.emplace_back(std::move(switch_node2));
graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);
graph_->nodes.emplace_back(std::move(merge_node));
graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);
RemoveUselessNode(switch_node_, graph_);
return RET_OK;
}
STATUS SingleSwitchPass::Run() {
int ret = Init();
if (ret != RET_OK) {
@ -519,24 +425,6 @@ STATUS SingleSwitchPass::Run() {
return ret;
}
if (switch_node_->inputIndex.size() == kSwitchMinInputSize) {
return RET_OK;
}
if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial ||
body_partial_node_->primitive->value.type != PrimitiveType_Partial) {
ret = ConvertSwitchToSelect();
return ret;
}
if (IsLoop()) {
ret = MoveMaxIterationToCond();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MoveMaxIterationToCond failed, ret: " << ret;
return ret;
}
}
ret = DoubleSwitchOutput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoubleSwitchOutput failed, ret: " << ret;

View File

@ -45,11 +45,9 @@ class SingleSwitchPass {
STATUS Init();
size_t InitThisGraphIndex();
STATUS DoubleSwitchOutput();
STATUS MoveMaxIterationToCond();
STATUS UpdateSwitchUser();
STATUS ConcatCondSubgraphInputAndOutput();
STATUS ConcatBodySubgraphInputAndOutput();
STATUS ConvertSwitchToSelect();
bool IsLoop();
STATUS InsertMerge();
STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node,

View File

@ -27,56 +27,71 @@ namespace mindspore {
namespace lite {
STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
std::vector<std::unique_ptr<schema::CNodeT>> newNodes;
std::vector<size_t> sinkedTensorIdxes;
// put all const tensor index into sinkedTensorIdxes
std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
std::vector<size_t> sinked_tensor_idxes;
// put all const tensor index into sinked_tensor_idxes
for (size_t i = 0; i < graph->allTensors.size(); i++) {
if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) {
sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i);
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i);
}
}
auto &oldNodes = graph->nodes;
std::queue<std::unique_ptr<schema::CNodeT>> opQueue;
// put all non depend node into queue
for (auto &node : graph->nodes) {
if (IsNodeNonDepend(node, sinkedTensorIdxes)) {
sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end());
opQueue.push(std::move(node));
}
}
// bfs
while (!opQueue.empty()) {
auto &node = opQueue.front();
auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get()));
for (auto postNodeIdx : postNodeIdxes) {
auto &postNode = oldNodes.at(postNodeIdx);
// check if postNode is non-depended
if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) {
sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end());
opQueue.push(std::move(postNode));
auto &old_nodes = graph->nodes;
std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
// put all none depend node into queue
for (size_t i = 0; i < graph->subGraph.size(); i++) {
std::vector<unsigned int> new_subgraph_node_indices = {};
auto subgraph_node_indices = graph->subGraph[i]->nodeIndices;
for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
auto &node = old_nodes[subgraph_node_indices[j]];
if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
op_queue.push(std::move(node));
}
}
newNodes.emplace_back(std::move(node));
opQueue.pop();
while (!op_queue.empty()) {
auto &node = op_queue.front();
auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get()));
for (auto post_node_idx : post_node_idxes) {
if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
auto &post_node = old_nodes.at(post_node_idx);
// check if post_node is non-depended
if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(),
post_node->outputIndex.end());
op_queue.push(std::move(post_node));
}
}
}
new_nodes.emplace_back(std::move(node));
new_subgraph_node_indices.push_back(new_nodes.size() - 1);
op_queue.pop();
}
graph->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
}
if (newNodes.size() != oldNodes.size()) {
MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size()
<< ", newNodesSize: " << newNodes.size();
if (new_nodes.size() != old_nodes.size()) {
MS_LOG(ERROR) << "Unknow error in TopologicalSort, old_nodes size: " << old_nodes.size()
<< ", new_nodes size: " << new_nodes.size();
return RET_ERROR;
}
graph->nodes.swap(newNodes);
graph->nodes.swap(new_nodes);
return RET_OK;
}
bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
const std::vector<size_t> &sinkedTensorIdxes) {
const std::vector<size_t> &sinked_tensor_idxes) {
MS_ASSERT(node != nullptr);
for (auto inputIdx : node->inputIndex) {
if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) {
return false;
}
if (node->primitive->value.type == schema::PrimitiveType_Merge) {
auto node_input_index = node->inputIndex;
MS_ASSERT(node_input_index.size() % 2 == 0);
return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2,
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }) ||
std::all_of(node_input_index.begin() + node_input_index.size() / 2, node_input_index.end(),
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); });
} else {
return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
}
return true;
}
} // namespace lite
} // namespace mindspore

View File

@ -54,6 +54,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_ptr_;
}

View File

@ -80,6 +80,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
MS_LOG(ERROR) << "convert graph outputs failed.";
return nullptr;
}
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_ptr_;
}

View File

@ -61,7 +61,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
}
primitive->value.type = schema::PrimitiveType_Mul;
primitive->value.value = attr.release();
} else if (tf_op.op() == "Div") {
} else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") {
auto attr = std::make_unique<schema::DivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
@ -154,6 +154,7 @@ TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser());
TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser());
TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser());
TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser());
TFNodeRegistrar g_tfRealDivParser("RealDiv", new TFArithmeticParser());
TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser());
TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser());
TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser());

View File

@ -37,10 +37,11 @@ static const std::vector<schema::PrimitiveType> tensorListOutputOpList = {
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr ret = nullptr;
if (anf_node_map.find(name) != anf_node_map.end()) {
ret = anf_node_map.at(name);
auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name);
if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name);
} else if (anf_node_map.find(name + ":0") != anf_node_map.end()) {
ret = anf_node_map.at(name + ":0");
ret = anf_node_map.at(flat_anf_name + ":0");
}
return ret;
}
@ -212,6 +213,17 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
if (status != RET_OK) {
return status;
}
} else if (type == kObjectTypeString) {
auto tensor_data = new (std::nothrow) string;
if (tensor_proto.string_val_size() == 1) {
string value = tensor_proto.string_val(0);
*tensor_data = value;
} else {
MS_LOG(ERROR) << "string size bigger than one, not support.";
return RET_ERROR;
}
tensor_size = (*tensor_data).size();
param_value->SetTensorData(tensor_data, tensor_size);
} else {
MS_LOG(ERROR) << "Unsupport dataType: " << type;
return RET_ERROR;
@ -318,6 +330,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
@ -364,7 +377,6 @@ STATUS TFModelParser::ConvertSubgraph() {
std::map<CNodePtr, FuncGraphPtr> while_cond_map;
std::map<CNodePtr, FuncGraphPtr> while_body_map;
for (int i = 0; i < subgraph_size; i++) {
std::vector<ParameterPtr> sub_graph_inputs;
auto &tf_sub_fuction = graph_def_liarary.function(i);
auto &tf_sub_signature = tf_sub_fuction.signature();
auto input_arg_size = tf_sub_signature.input_arg_size();
@ -381,13 +393,17 @@ STATUS TFModelParser::ConvertSubgraph() {
}
FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
// convert sub graph inputs
std::vector<ParameterPtr> sub_graph_inputs;
for (int j = 0; j < input_arg_size; j++) {
auto &input_arg = tf_sub_signature.input_arg(j);
auto paramter = sub_func_graph->add_parameter();
paramter->set_name(input_arg.name());
anf_sub_node_map[input_arg.name()] = paramter;
auto root_while_inputs = while_cnode->inputs();
paramter->set_abstract(root_while_inputs[j + 1]->abstract());
sub_graph_inputs.emplace_back(paramter);
}
std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
@ -452,8 +468,19 @@ STATUS TFModelParser::ConvertSubgraph() {
}
// hardcode subgraph inputs name
for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter");
sub_graph_inputs[j]->set_name(sub_graph_name + "_input_" + std::to_string(j) + "_parameter");
}
// hardcode subgraph outputs name
for (size_t j = 1; j < sub_output_nodes.size(); j++) {
if (utils::isa<CNodePtr>(sub_output_nodes[j])) {
sub_output_nodes[j]->cast<CNodePtr>()->set_fullname_with_scope(sub_graph_name + "_output_" +
std::to_string(j - 1) + "_cnode");
} else if (utils::isa<ParameterPtr>(sub_output_nodes[j])) {
sub_output_nodes[j]->cast<ParameterPtr>()->set_name(sub_graph_name + "_output_" + std::to_string(j - 1) +
"_parameter");
}
}
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
}
auto status = WhileNodePostProcess(while_cond_map, while_body_map);
@ -469,9 +496,8 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr
MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR;
}
std::vector<FuncGraphPtr> roots = {anf_root_graph_};
auto root_func_manager = std::make_shared<FuncGraphManager>(roots);
anf_root_graph_->set_manager(root_func_manager);
static auto root_func_manager = Manage(anf_root_graph_);
for (auto &kv : while_cond_map) {
auto while_node = kv.first;
auto &cond_sub_graph = kv.second;
@ -633,6 +659,11 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
for (auto &pair : tf_root_graph_nodes_) {
for (int i = 0; i < pair.second->input_size(); ++i) {
all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i)));
auto input_name = pair.second->input(i);
if (input_name[0] == '^') {
input_name.erase(0, 1);
}
all_node_inputs.insert(input_name);
}
}
for (auto &pair : tf_root_graph_nodes_) {
@ -644,7 +675,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node";
MS_LOG(ERROR) << "can't find anf node: " << origin_name;
return RET_ERROR;
}
output_nodes.push_back(anf_node);

View File

@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_ragged_range_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF RaggedRangeParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RangeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) {
MS_LOG(ERROR) << "The starts attr should be specified";
return RET_ERROR;
}
attr->start = static_cast<int32_t>(attr_value.i());
if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) {
MS_LOG(ERROR) << "The limits attr should be specified";
return RET_ERROR;
}
attr->limit = static_cast<int32_t>(attr_value.i());
if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) {
MS_LOG(ERROR) << "The deltas attr should be specified";
return RET_ERROR;
}
attr->delta = static_cast<int32_t>(attr_value.i());
primitive->value.type = schema::PrimitiveType_Range;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGFED_RANGE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFRaggedRangeParser : public TFNodeParser {
public:
TFRaggedRangeParser() = default;
~TFRaggedRangeParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_

View File

@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_range_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF RangeParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RangeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) {
MS_LOG(ERROR) << "The start attr should be specified";
return RET_ERROR;
}
attr->start = static_cast<int32_t>(attr_value.i());
if (!TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) {
MS_LOG(ERROR) << "The limit attr should be specified";
return RET_ERROR;
}
attr->limit = static_cast<int32_t>(attr_value.i());
if (!TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) {
MS_LOG(ERROR) << "The delta attr should be specified";
return RET_ERROR;
}
attr->delta = static_cast<int32_t>(attr_value.i());
primitive->value.type = schema::PrimitiveType_Range;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFRangeParser : public TFNodeParser {
public:
TFRangeParser() = default;
~TFRangeParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ROUND_PARSER_H_

View File

@ -9,7 +9,7 @@
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WRRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

View File

@ -18,8 +18,8 @@
#include <string>
#include <vector>
#include <string_view>
#include <unordered_map>
#include <regex>
#include <unordered_map>
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"

View File

@ -76,6 +76,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
func_graph_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_;
}

View File

@ -21,7 +21,22 @@
#include "mindspore/lite/src/ops/primitive_c.h"
#include "tools/anf_importer/import_from_meta_graphT.h"
using mindspore::lite::RET_INFER_INVALID;
namespace mindspore::opt {
ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
auto para_value_lite = std::make_shared<ParamValueLite>();
if (para_value_lite == nullptr) {
MS_LOG(ERROR) << "new ParamValueLite failed";
return nullptr;
}
para_value_lite->set_tensor_shape(tensor->shape());
para_value_lite->set_tensor_type(tensor->data_type());
para_value_lite->set_format(tensor->format());
return para_value_lite;
}
abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
MS_ASSERT(nullptr != tensor);
std::vector<int> shape(tensor->shape());
@ -33,15 +48,30 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
MS_LOG(ERROR) << "new AbstractTensor failed";
return nullptr;
}
auto new_value = std::make_shared<ParamValueLite>();
if (new_value == nullptr) {
auto para_value_lite = NewParamValueLitePtr(tensor);
if (para_value_lite == nullptr) {
MS_LOG(ERROR) << "new ParamValueLite failed";
return nullptr;
}
new_value->set_tensor_shape(tensor->shape());
new_value->set_tensor_type(tensor->data_type());
new_value->set_format(tensor->format());
new_abstract->set_value(new_value);
if (type_id == kObjectTypeTensorType) {
auto tensor_list = dynamic_cast<lite::TensorList *>(tensor);
if (tensor_list == nullptr) {
MS_LOG(ERROR) << "cast tensor_list failed";
return nullptr;
}
auto tensor_info = new int[tensor_list->element_shape().size() + 2];
tensor_info[0] = tensor_list->tensors_data_type();
tensor_info[1] = tensor_list->element_shape().size();
for (size_t i = 0; i < tensor_list->element_shape().size(); ++i) {
tensor_info[i + 2] = tensor_list->element_shape()[i];
}
para_value_lite->set_tensor_addr(tensor_info);
para_value_lite->set_tensor_size(tensor_list->element_shape().size() + 2);
}
new_abstract->set_value(para_value_lite);
return new_abstract;
}
@ -121,13 +151,13 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
}
if (utils::isa<ValueNodePtr>(cnode->input(i))) {
MS_LOG(WARNING) << "input is value node";
MS_LOG(WARNING) << cnode->fullname_with_scope() << "'s input[" << i << "] is value node";
continue;
}
AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Abstract of CNode is nullptr";
MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr";
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
@ -194,7 +224,7 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<
MS_ASSERT(output_tensors != nullptr);
auto abstract = cnode->abstract();
if (abstract == nullptr) {
MS_LOG(ERROR) << "abstract is nullptr";
MS_LOG(ERROR) << "node " << cnode->fullname_with_scope() << " abstract is nullptr";
return RET_ERROR;
}
std::vector<TypeId> types;
@ -264,7 +294,62 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
return RET_OK;
}
int InferShapePass::StrIsContain(const std::vector<std::string> &total, const std::string &aim) {
for (size_t i = 0; i < total.size(); i++) {
if (aim.find(total[i]) != std::string::npos) {
return i;
}
}
return -1;
}
STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
// hard code construct input parameter name
std::vector<std::string> inputs_names{};
for (size_t i = 1; i < cnode->inputs().size(); i++) {
inputs_names.emplace_back("_input_" + std::to_string(i - 1) + "_parameter");
}
// copy cnode input to func_graph input
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (utils::isa<ParameterPtr>(node)) {
auto pos = StrIsContain(inputs_names, node->fullname_with_scope());
if (pos != -1) {
auto pnode = utils::cast<ParameterPtr>(node);
auto input_pnode = utils::cast<ParameterPtr>(cnode->input(pos + 1));
MS_ASSERT(pnode != nullptr);
pnode->set_abstract(input_pnode->abstract());
}
}
}
return RET_OK;
}
STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) {
auto body_partial_cnode = switch_cnode->input(2)->cast<CNodePtr>();
MS_ASSERT(body_partial_cnode != nullptr);
auto body_vnode = body_partial_cnode->input(0)->cast<ValueNodePtr>();
MS_ASSERT(body_vnode != nullptr);
auto body_fg = GetValueNode<FuncGraphPtr>(body_vnode);
MS_ASSERT(body_fg != nullptr);
AbstractBasePtrList abstract_list;
auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output());
for (auto &cnode : body_fg_output_cnode->inputs()) {
if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) {
continue;
}
abstract_list.push_back(cnode->abstract());
}
switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return RET_OK;
}
bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
if (func_graph->has_flag("HasInferShaped")) {
return true;
}
if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
return false;
@ -287,8 +372,14 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
auto cnode = node->cast<CNodePtr>();
auto origin_primc = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0));
if (origin_primc == nullptr) {
MS_LOG(ERROR) << "origin_primc is nullptr";
return false;
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
if (sub_func_graph == nullptr) {
MS_LOG(ERROR) << "node " << node->fullname_with_scope() << "'s origin_primc is nullptr";
return false;
} else {
MS_LOG(WARNING) << "subgraph infer shape invalid.";
return RET_INFER_INVALID;
}
}
auto origin_primt = origin_primc->primitiveT();
if (origin_primt == nullptr) {
@ -296,6 +387,15 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
return false;
}
auto type = GetCNodeType(cnode);
if (type == schema::PrimitiveType_Switch) {
int ret = SwitchCNodeInferShape(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PartialCNodeInferShape failed.";
return false;
}
}
if ((type == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) ||

View File

@ -41,6 +41,9 @@ class InferShapePass : public Pass {
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
STATUS SetParameterAbstract(const ParameterPtr &parameter);
STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode);
STATUS SwitchCNodeInferShape(const CNodePtr &cnode);
int StrIsContain(const std::vector<std::string> &total, const std::string &aim);
int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph);
private:
FmkType fmk_type = lite::converter::FmkType_ONNX;

View File

@ -0,0 +1,181 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/while_pass.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "mindspore/lite/include/errorcode.h"
#include "mindspore/lite/src/ops/primitive_c.h"
#include "tools/anf_importer/import_from_meta_graphT.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/ops/primitive_c.h"
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "src/common/log_adapter.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"
namespace mindspore::opt {
ValueNodePtr WhilePass::GetSwitchAnfPrim() {
auto switch_primitiveT = new (std::nothrow) schema::PrimitiveT;
if (switch_primitiveT == nullptr) {
MS_LOG(ERROR) << "new switch_primitiveT failed";
return nullptr;
}
switch_primitiveT->value.type = schema::PrimitiveType_Switch;
switch_primitiveT->value.value = new (std::nothrow) schema::SwitchT;
if (switch_primitiveT->value.value == nullptr) {
MS_LOG(ERROR) << "new MakeTupleT failed";
return nullptr;
}
auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT);
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
return partial_anf_prim;
}
void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode,
std::string para_name) {
for (auto &node : node_list) {
if (utils::isa<CNodePtr>(node)) {
auto cnode = utils::cast<CNodePtr>(node);
for (size_t k = 0; k < cnode->inputs().size(); k++) {
if (!utils::isa<ParameterPtr>(cnode->input(k))) {
continue;
}
auto para_input = utils::cast<ParameterPtr>(cnode->input(k));
if (para_input->name() == para_name) {
cnode->set_input(k, new_input_cnode);
}
}
}
}
}
bool WhilePass::Run(const FuncGraphPtr &graph) {
auto node_list = TopoSort(graph->get_return());
static int count = 0;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
if (opt::GetCNodeType(node) != schema::PrimitiveType_While) {
continue;
}
auto while_cnode = node->cast<CNodePtr>();
MS_ASSERT(while_cnode != nullptr);
if (while_cnode->inputs().size() < kWhileMinInputSize) {
MS_LOG(ERROR) << "while input is not right.";
return false;
}
// the order is fixed.
auto cond_vnode = while_cnode->input(kWhileCondIndex);
auto body_vnode = while_cnode->input(kWhileBodyIndex);
// body_vnode->cast<ValueNodePtr>()->set_value()
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
if (cond_fg == nullptr || body_fg == nullptr) {
MS_LOG(ERROR) << "Get value as func_graph failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
return false;
}
// create cond partial cnode
std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode};
// create body partial cnode
std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode};
// add while op input to cond_cnode and body_cnode
cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
while_cnode->inputs().end());
body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
while_cnode->inputs().end());
static int idx = 0;
auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs);
cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx));
cond_partial_node->set_abstract(cond_fg->output()->abstract());
auto body_partial_node = graph->NewCNode(body_partial_op_inputs);
body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx));
idx++;
// concat body_fg output to cond_fg input
auto body_output = body_fg->output();
auto body_output_cnode = utils::cast<CNodePtr>(body_output);
auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(body_output_cnode->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Get PrimitiveC of node:" << body_output_cnode->fullname_with_scope() << " failed.";
return false;
}
// concat body to cond
std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
if ((schema::PrimitiveType)(prim->Type()) == schema::PrimitiveType_MakeTuple) {
for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
}
} else {
body_to_cond_inputs.emplace_back(body_output_cnode->input(1));
}
// concat body to cond
auto body_to_cond_cnode = body_fg->NewCNode(body_to_cond_inputs);
body_to_cond_cnode->set_fullname_with_scope("Partial-while-body-to-cond");
auto body_fg_manager = body_fg->manager();
body_fg_manager->Replace(body_fg->output(), body_to_cond_cnode);
body_fg->set_output(body_to_cond_cnode);
body_partial_node->set_abstract(cond_fg->output()->abstract());
// create switch cnode
ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim();
if (switch_anf_primitive == nullptr) {
MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
return false;
}
// insert switch node
std::vector<AnfNodePtr> switch_op_inputs = {switch_anf_primitive, cond_partial_node, body_partial_node};
auto switch_cnode = graph->NewCNode(switch_op_inputs);
switch_cnode->set_fullname_with_scope("Switch-" + std::to_string(count++));
AbstractBasePtrList abstract_list;
auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output());
for (auto &cnode : body_fg_output_cnode->inputs()) {
if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) {
continue;
}
abstract_list.push_back(cnode->abstract());
}
switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
// create cond partial cnode
auto manager = graph->manager();
auto node_users = manager->node_users()[while_cnode];
for (auto &node_user : node_users) {
manager->SetEdge(node_user.first, node_user.second, switch_cnode);
}
}
return true;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "src/param_value_lite.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class WhilePass : public Pass {
public:
WhilePass() : Pass("while_pass") {}
~WhilePass() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name);
ValueNodePtr GetSwitchAnfPrim();
const size_t kWhileMinInputSize = 3;
const size_t kWhileCondIndex = 1;
const size_t kWhileBodyIndex = 2;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_