forked from mindspore-Ecosystem/mindspore
parent 28052ad188
commit 89f96e347b
@@ -20,6 +20,7 @@
#include <map>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/ops/assert_op.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/space_to_batch_nd.h"
#include "src/ops/conv2d.h"
@@ -614,6 +615,13 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Sqrt>(prim, inputs, quantType);
} else if (op_type == "Greater") {
return NewPrimitiveC<Greater>(prim, inputs, quantType);
} else if (op_type == "Switch") {
return NewPrimitiveC<Switch>(prim, inputs, quantType);
} else if (op_type == "Partial") {
return NewPrimitiveC<Partial>(prim, inputs, quantType);
} else if (op_type == "Merge") {
return NewPrimitiveC<Merge>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
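
Note on the hunk above: each new control-flow primitive (Switch, Partial, Merge) is registered by adding one branch to the op-type string chain, all funneling into the same templated factory. A minimal standalone sketch of that dispatch idiom, using hypothetical stand-in types rather than MindSpore's real PrimitiveC hierarchy:

    #include <iostream>
    #include <memory>
    #include <string>

    // Hypothetical stand-ins for the real PrimitiveC hierarchy.
    struct PrimitiveC { virtual ~PrimitiveC() = default; virtual std::string name() const = 0; };
    struct Switch : PrimitiveC { std::string name() const override { return "Switch"; } };
    struct Merge  : PrimitiveC { std::string name() const override { return "Merge"; } };

    // One templated factory + one if/else chain on the op-type string.
    template <typename T>
    std::shared_ptr<PrimitiveC> NewPrimitiveC() { return std::make_shared<T>(); }

    std::shared_ptr<PrimitiveC> Create(const std::string &op_type) {
      if (op_type == "Switch") {
        return NewPrimitiveC<Switch>();
      } else if (op_type == "Merge") {
        return NewPrimitiveC<Merge>();
      }
      return nullptr;  // unknown op type
    }

    int main() {
      auto p = Create("Merge");
      std::cout << (p ? p->name() : "unsupported") << "\n";  // prints "Merge"
    }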
@@ -955,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Merge(primitive);
case schema::PrimitiveType_Partial:
return new (std::nothrow) Partial(primitive);
case schema::PrimitiveType_Assert:
return new (std::nothrow) AssertOP(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);
@@ -156,7 +156,8 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Transpose);
MS_ASSERT(desc.type == schema::PrimitiveType_Transpose || desc.type == schema::PrimitiveType_Nchw2Nhwc ||
desc.type == schema::PrimitiveType_Nhwc2Nchw);
if (opParameter == nullptr) {
MS_LOG(ERROR) << "desc type is not Transpose";
return nullptr;
@@ -200,6 +200,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/while_pass.cc
)
endif()
### train
@@ -7,7 +7,7 @@ rcnn-ilsvrc13-9.onnx
mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx
densenet-9.onnx
#densenet-9.onnx
ml_table_detection_fp32.onnx
ml_table_segment.onnx
googlenet-9.onnx
@@ -27,7 +27,7 @@ psenet_lite_mbv2.onnx;1,32,32,3
super-resolution-10.onnx;1,224,224,1
tinyyolov2-8.onnx;1,416,416,3
ml_2012_ocr_cn.onnx
ml_2012_ocr_cn_noLSTM.onnx
#ml_2012_ocr_cn_noLSTM.onnx
candy-9.onnx
mosaic-9.onnx
pointilism-9.onnx
@@ -7,7 +7,7 @@ emotion-ferplus-8.onnx 1
mobilenetv2-7.onnx 8
shufflenet-v2-10.onnx 5
squeezenet1.1-7.onnx 1
densenet-9.onnx 6
#densenet-9.onnx 6
ml_table_detection_fp32.onnx 2
ml_table_segment.onnx 2
googlenet-9.onnx 3
@@ -27,7 +27,7 @@ mnist-8.onnx 10
#super-resolution-10.onnx 1
#tinyyolov2-8.onnx 0.3
ml_2012_ocr_cn.onnx 200
ml_2012_ocr_cn_noLSTM.onnx 1
#ml_2012_ocr_cn_noLSTM.onnx 1
candy-9.onnx 5
mosaic-9.onnx 4
pointilism-9.onnx 3
@@ -28,6 +28,8 @@
#include "src/tensor.h"
#include "src/param_value_lite.h"
#include "src/common/utils.h"
#include "src/ops/partial.h"
#include "tools/common/graph_util.h"

namespace mindspore::lite {
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
@@ -73,7 +75,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
hasDepend = true;
bool maskOut = (dependNode->inputs().size() == 3) ? true : false;
bool maskOut = (dependNode->inputs().size() == 3);
for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
AnfNodePtr dependInputNode = dependNode->input(j);
if (dependInputNode->isa<CNode>()) {
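
The maskOut change above drops a redundant ternary: a comparison already yields a bool, so `(x) ? true : false` and the bare comparison are equivalent. A compile-and-run check of that equivalence:

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> inputs{0, 1, 2};
      // Before: redundant ternary on an expression that is already a bool.
      bool maskOutOld = (inputs.size() == 3) ? true : false;
      // After: the comparison itself is the bool.
      bool maskOutNew = (inputs.size() == 3);
      assert(maskOutOld == maskOutNew);
    }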
@@ -172,22 +174,50 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
return RET_OK;
}

void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
for (auto node : graph_input_nodes_) {
std::vector<schema::CNodeT *> AnfExporter::GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index) {
std::vector<schema::CNodeT *> subgraph_nodes{};
subgraph_nodes.resize(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.size());
std::transform(meta_graphT->subGraph.at(subgraph_index)->nodeIndices.begin(),
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.end(), subgraph_nodes.begin(),
[&meta_graphT](const uint32_t idx) { return meta_graphT->nodes.at(idx).get(); });
return subgraph_nodes;
}

int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index) {
auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
auto subgraph_nodes = GetSubgraphNodes(meta_graphT, subgraph_index);
std::vector<schema::CNodeT *> subgraph_input_nodes{};
for (auto &node : subgraph_nodes) {
if (IsContain(graph_input_nodes_, node)) {
subgraph_input_nodes.push_back(node);
}
}
std::vector<schema::TensorT *> subgraph_inputs{};
for (auto &node : subgraph_input_nodes) {
for (auto input : node->inputIndex) {
auto tensor = meta_graphT->allTensors[input].get();
if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) {
tensor->nodeType = schema::NodeType_ValueNode;
tensor->format = schema::Format_NHWC;
if (!IsContain(meta_graphT->inputIndex, input)) {
meta_graphT->inputIndex.emplace_back(input);
if (!IsContain(subgraph->inputIndices, input)) {
if (subgraph_index == kMainGraphIndex) {
meta_graphT->inputIndex.push_back(input);
}
subgraph->inputIndices.push_back(input);
subgraph_inputs.push_back(tensor);
}
}
}
}

return RET_OK;
}

int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT,
schema::CNodeT *return_node) {
MS_ASSERT(nullptr != meta_graphT);
MS_ASSERT(nullptr != return_node);
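
GetSubgraphNodes above resolves a subgraph's stored node indices into raw CNodeT pointers with std::transform. A self-contained sketch of the same mapping, using simplified stand-ins for the schema types (the real flatbuffer-generated types differ):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Simplified, hypothetical stand-ins for schema::CNodeT / schema::MetaGraphT.
    struct CNodeT { std::string name; };
    struct SubGraphT { std::vector<uint32_t> nodeIndices; };
    struct MetaGraphT {
      std::vector<std::unique_ptr<CNodeT>> nodes;
      std::vector<std::unique_ptr<SubGraphT>> subGraph;
    };

    // Same idiom as GetSubgraphNodes: map per-subgraph node indices to raw pointers.
    std::vector<CNodeT *> GetSubgraphNodes(const MetaGraphT &g, size_t subgraph_index) {
      const auto &indices = g.subGraph.at(subgraph_index)->nodeIndices;
      std::vector<CNodeT *> out(indices.size());
      std::transform(indices.begin(), indices.end(), out.begin(),
                     [&g](uint32_t idx) { return g.nodes.at(idx).get(); });
      return out;
    }

    int main() {
      MetaGraphT g;
      g.nodes.push_back(std::make_unique<CNodeT>(CNodeT{"conv"}));
      g.nodes.push_back(std::make_unique<CNodeT>(CNodeT{"relu"}));
      g.subGraph.push_back(std::make_unique<SubGraphT>(SubGraphT{{1}}));
      for (auto *n : GetSubgraphNodes(g, 0)) std::cout << n->name << "\n";  // relu
    }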
@@ -202,28 +232,62 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt
MS_LOG(ERROR) << "obtain outputs failed";
return ret;
}
} else if (input_node->isa<Parameter>()) {
MS_LOG(INFO) << "the node " << input_node->fullname_with_scope().c_str() << "is parameter node";
continue;
} else {
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
return RET_ERROR;
}
}
for (unsigned int &i : return_node->inputIndex) {
meta_graphT->outputIndex.push_back(i);
if (subgraph_index == kMainGraphIndex) {
meta_graphT->outputIndex.push_back(i);
}
meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
}
return RET_OK;
}

schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
auto cnodes = func_graph->GetOrderedCnodes();
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode) {
int ret = RET_OK;
meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>());
auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index);
auto subgraph_name = func_graph->get_attr("graph_name");
MS_ASSERT(subgraph_name != nullptr);
sub_graphT->name = GetValue<std::string>(subgraph_name);

auto cnodes = func_graph->GetOrderedCnodes();
for (const auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
ret = RET_MEMORY_FAILED;
break;
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
if (fg != nullptr) {
auto partial_cnode = CreatePartialCnode(fg, cnode);
primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(partial_cnode->input(0));
auto primT = primitive_c->primitiveT();
auto pos = fg_subgraph_map.find(fg);
if (pos != fg_subgraph_map.end()) {
primT->value.AsPartial()->subGraphIndex = fg_subgraph_map.at(fg);
} else {
size_t next_subgraph_index = fg_subgraph_map.size() + 1;
fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index});
primT->value.AsPartial()->subGraphIndex = next_subgraph_index;
ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ExportSubgraph failed";
break;
}
}
} else {
MS_LOG(ERROR) << "primitive_c is nullptr";
ret = RET_MEMORY_FAILED;
break;
}
}

#ifdef SUPPORT_TRAIN
RemoveIfMakeTuple(cnode);
RemoveIfDepend(cnode);
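
ExportSubgraph above recurses into each FuncGraph reached through a Partial call and memoizes already-exported graphs in fg_subgraph_map, so a graph shared by several call sites is exported once and every Partial referencing it gets the same subGraphIndex. A runnable sketch of that memoized recursion, with hypothetical stand-in types:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Hypothetical stand-in: a function graph identified by name, possibly calling others.
    struct FuncGraph {
      std::string name;
      std::vector<FuncGraph *> callees;
    };

    // Mirrors fg_subgraph_map: a graph already seen keeps its index;
    // a new one gets the next index and is exported recursively.
    std::map<FuncGraph *, int> fg_subgraph_map;

    int ExportSubgraph(FuncGraph *fg) {
      int index = static_cast<int>(fg_subgraph_map.size());
      fg_subgraph_map[fg] = index;
      std::cout << "export " << fg->name << " as subgraph " << index << "\n";
      for (auto *callee : fg->callees) {
        auto pos = fg_subgraph_map.find(callee);
        if (pos == fg_subgraph_map.end()) ExportSubgraph(callee);  // first visit: export
        // else: reuse pos->second; no re-export for shared subgraphs
      }
      return index;
    }

    int main() {
      FuncGraph cond{"cond", {}}, body{"body", {}}, main_g{"main", {&cond, &body, &cond}};
      ExportSubgraph(&main_g);  // "cond" is exported once despite two call sites
    }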
@@ -249,13 +313,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
}
if (primT->value.type == schema::PrimitiveType_Return) {
node->name = "return_node";
ret = SetGraphoutputIndex(cnode, meta_graphT, node.get());
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpOutputN failed";
break;
}
continue;
}

node->nodeType = schema::NodeType_CNode;
node->name = cnode->fullname_with_scope();
if (copy_primitive) {
@@ -281,21 +346,45 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
if (!keep_graph) {
primitive_c->ClearPrimitiveT();
}
meta_graphT->nodes.emplace_back(std::move(node));
meta_graphT->nodes.push_back(std::move(node));
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++);
}
if (ret != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}

ret = SetGraphInputIndex(meta_graphT, subgraph_index);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetGraphInputIndex failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}

ret = SetSubgraphTensorIndices(meta_graphT.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}

return RET_OK;
}

schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
static int subgraph_index = 0;
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
if (ret != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
// set graph input tensors
SetGraphInputIndex(meta_graphT);
return meta_graphT.release();
}

int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
std::string input_name = input_anode->fullname_with_scope();
auto input_cnode = utils::cast<CNodePtr>(input_anode);

if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
#ifndef SUPPORT_TRAIN
if (node_id_map_.find(input_name) != node_id_map_.end()) {
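
The rewritten Export wrapper above builds the MetaGraphT inside a unique_ptr and releases ownership to the caller only on success, so every early error return frees the graph automatically. A minimal sketch of that ownership pattern, assuming a trivial stand-in type:

    #include <iostream>
    #include <memory>

    struct MetaGraphT { int node_count = 0; };  // hypothetical stand-in

    // Build under unique_ptr so all early-return error paths free the graph,
    // then hand raw ownership to the caller only on success - as Export() does.
    MetaGraphT *Export(bool fail) {
      auto meta_graphT = std::make_unique<MetaGraphT>();
      meta_graphT->node_count = 3;
      if (fail) return nullptr;      // meta_graphT freed automatically here
      return meta_graphT.release();  // caller now owns the raw pointer
    }

    int main() {
      std::unique_ptr<MetaGraphT> g(Export(false));
      std::cout << (g ? g->node_count : -1) << "\n";  // 3
    }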
@@ -343,11 +432,11 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0);  // try name with 0
iter = node_id_map_.find(input_index_key);
if (iter == node_id_map_.end()) {
MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key;
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
return RET_ERROR;
}
#else
MS_LOG(ERROR) << "Can not find get_item output tensor" << input_index_key;
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
return RET_ERROR;
#endif
}
@@ -367,6 +456,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano
}
auto paramTensor = std::make_unique<schema::TensorT>();
paramTensor->format = schema::Format_NHWC;
paramTensor->name = paramNode->name();
auto abstractBase = paramNode->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
@@ -518,6 +608,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<FuncGraph>()) {
MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
return RET_OK;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;
@@ -644,6 +737,20 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
}

bool AnfExporter::HasPrimitiveCNode(const AnfNodePtr &node) {
MS_ASSERT(node != nullptr);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return false;
}

auto prim = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (prim == nullptr) {
return false;
}
return true;
}

bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) {
MS_ASSERT(node != nullptr);
auto cnode = node->cast<CNodePtr>();
@@ -658,6 +765,47 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType
return (schema::PrimitiveType)(prim->Type()) == type;
}

ValueNodePtr AnfExporter::GetPartialAnfPrim() {
auto partial_primitiveT = new (std::nothrow) schema::PrimitiveT;
if (partial_primitiveT == nullptr) {
MS_LOG(ERROR) << "new partial_primitiveT failed";
return nullptr;
}
partial_primitiveT->value.type = schema::PrimitiveType_Partial;
partial_primitiveT->value.value = new (std::nothrow) schema::PartialT;
if (partial_primitiveT->value.value == nullptr) {
MS_LOG(ERROR) << "new PartialT failed";
return nullptr;
}

auto partial_prim = std::make_shared<lite::Partial>(partial_primitiveT);
ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
return partial_anf_prim;
}

CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node) {
if (utils::isa<CNodePtr>(node)) {
auto cnode = utils::cast<CNodePtr>(node);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c != nullptr) {
return cnode;
}
auto partial_anf_prim_vnode = GetPartialAnfPrim();
auto cnode_input = cnode->inputs();
cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode);
cnode->set_inputs(cnode_input);
return cnode;
} else if (utils::isa<ValueNodePtr>(node)) {
auto partial_anf_prim_vnode = GetPartialAnfPrim();
std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node};
auto cnode = fg->NewCNode(inputs);
return cnode;
} else {
MS_LOG(ERROR) << "failed to create partial cnode.";
return nullptr;
}
}

schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
AnfExporter anf_exporter;
return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
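
CreatePartialCnode above normalizes a call whose operator slot holds a FuncGraph (or a bare value node) into a CNode headed by a Partial primitive, so the exporter always sees a primitive-headed node. A heavily simplified, hypothetical model of that insertion:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Hypothetical, heavily simplified ANF node model.
    struct AnfNode { std::string tag; };
    using AnfNodePtr = std::shared_ptr<AnfNode>;
    struct CNode {
      std::vector<AnfNodePtr> inputs;  // inputs[0] is the operator position
    };

    // Mirror of CreatePartialCnode's core move: if input(0) is not already a
    // primitive, prepend a freshly built Partial value node.
    void WrapWithPartial(CNode *cnode) {
      if (!cnode->inputs.empty() && cnode->inputs[0]->tag == "primitive") return;
      auto partial = std::make_shared<AnfNode>(AnfNode{"primitive"});  // the Partial prim
      cnode->inputs.insert(cnode->inputs.begin(), partial);
    }

    int main() {
      CNode call;
      call.inputs.push_back(std::make_shared<AnfNode>(AnfNode{"func_graph"}));
      call.inputs.push_back(std::make_shared<AnfNode>(AnfNode{"arg0"}));
      WrapWithPartial(&call);
      std::cout << call.inputs.size() << "\n";  // 3: Partial prim + func_graph + arg0
    }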
@@ -27,6 +27,10 @@
#include "tools/converter/converter_context.h"

namespace mindspore::lite {

constexpr const int kPartialMinSize = 3;
constexpr const int kMainGraphIndex = 0;

class AnfExporter {
public:
AnfExporter() = default;
@@ -45,17 +49,28 @@ class AnfExporter {
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node);
int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT, schema::CNodeT *return_node);
static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
static bool HasPrimitiveCNode(const AnfNodePtr &node);
static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);
int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode = nullptr);
ValueNodePtr GetPartialAnfPrim();
CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode);
std::vector<schema::CNodeT *> GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index);

private:
std::map<std::string, int> node_id_map_;
std::vector<schema::CNodeT *> graph_input_nodes_;
std::map<FuncGraphPtr, int> fg_subgraph_map;
uint32_t node_idx = 0;
};
// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
// but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify
@@ -272,18 +272,40 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
continue;
}
}
// update graph input indexes
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indexes
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}

for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
// update subgraph tensor indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}

// update nodes indexes
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
// update nodes input indexes
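
All of the new loops above apply the same rule: once the tensor at deleteIdx is erased from allTensors, every stored index greater than deleteIdx shifts down by one, in graph, subgraph, and node index lists alike. A runnable sketch of the rule:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // The decrement-above-deleted-index rule used throughout RemoveTensor:
    // erasing element `deleted` from the tensor array shifts all later indices left.
    void ShiftIndices(std::vector<uint32_t> *indices, uint32_t deleted) {
      for (auto &idx : *indices) {
        if (idx > deleted) --idx;
      }
    }

    int main() {
      std::vector<uint32_t> outputIndices{0, 2, 5};
      ShiftIndices(&outputIndices, 3);  // tensor 3 erased from allTensors
      for (auto idx : outputIndices) std::cout << idx << " ";  // 0 2 4
    }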
@@ -768,5 +790,30 @@ std::string GetModelName(const std::string &modelFile) {
modelName = modelName.substr(0, modelName.find_last_of('.'));
return modelName;
}

int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
for (auto &subgraph : meta_graphT->subGraph) {
std::vector<uint32_t> subgraph_indices{};
for (auto &node_idx : subgraph->nodeIndices) {
auto &node = meta_graphT->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
if (IsContain(subgraph_indices, input_idx)) {
continue;
} else {
subgraph_indices.push_back(input_idx);
}
}
for (auto &output_idx : node->outputIndex) {
if (IsContain(subgraph_indices, output_idx)) {
continue;
} else {
subgraph_indices.push_back(output_idx);
}
}
}
subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
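
SetSubgraphTensorIndices gathers, in first-use order, every tensor index the subgraph's nodes touch; the IsContain check on a vector makes each membership test O(n). A sketch of an equivalent order-preserving de-duplication using an unordered_set for O(1) lookups (an alternative formulation, not the project's code):

    #include <cstdint>
    #include <iostream>
    #include <unordered_set>
    #include <vector>

    // Order-preserving de-duplication of the tensor indices a node list touches.
    // Same result as the IsContain/push_back loop, with O(1) membership tests.
    std::vector<uint32_t> CollectTensorIndices(const std::vector<std::vector<uint32_t>> &per_node) {
      std::vector<uint32_t> out;
      std::unordered_set<uint32_t> seen;
      for (const auto &indices : per_node) {
        for (uint32_t idx : indices) {
          if (seen.insert(idx).second) out.push_back(idx);
        }
      }
      return out;
    }

    int main() {
      // node0 touches tensors {0,1,2}; node1 touches {2,1,3}
      auto tensors = CollectTensorIndices({{0, 1, 2}, {2, 1, 3}});
      for (auto t : tensors) std::cout << t << " ";  // 0 1 2 3
    }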
@@ -92,6 +92,8 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo

STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);

STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);

std::string GetModelName(const std::string &modelFile);
} // namespace lite
} // namespace mindspore
@@ -59,6 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/onnx_inputs_adjust_pass.cc
../optimizer/graph/while_pass.cc
)

add_subdirectory(../anf_importer anf_importer)
@@ -42,6 +42,7 @@
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/while_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -52,18 +53,21 @@ AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;

FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
if (config == nullptr) {
MS_LOG(ERROR) << "config shoud be specified";
MS_LOG(ERROR) << "config should be specified";
return nullptr;
}
if (old_graph->has_flag("HasTransformed")) {
old_graph->set_flag("HasTransformed", false);
return old_graph;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);

// mindir pre adjustment
if (config->fmk == converter::FmkType_MS) {
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
@@ -85,7 +89,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
}

// for now - trainning is not supporting fuse operations
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) {
auto while_pass = std::make_shared<opt::WhilePass>();
graph_pm->AddPass(while_pass);
}

// for now - training is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
@@ -191,7 +200,46 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
}

return new_graph;
}

STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs,
std::vector<ValueNodePtr> *vnodes) {
auto nodes = TopoSort(main_graph->get_return());
for (auto &node : nodes) {
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg) {
vnodes->push_back(utils::cast<ValueNodePtr>(node));
subgraphs->push_back(fg);
}
}
return RET_OK;
}

FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
// transform main_graph
auto new_main_graph = TransformSingleFuncGraph(main_graph, config);
if (new_main_graph == nullptr) {
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}

// transform sub_graph
FuncGraphPtrList subgraphs{};
std::vector<ValueNodePtr> vnodes{};
int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
for (size_t i = 0; i < subgraphs.size(); i++) {
auto new_graph = Transform(subgraphs.at(i), config);
new_graph->set_flag("HasTransformed", true);
vnodes.at(i)->set_value(new_graph);
}

return new_main_graph;
}
} // namespace mindspore::lite
@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H

#include <memory>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
@@ -34,6 +35,9 @@ class AnfTransform {
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);

private:
STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs,
std::vector<ValueNodePtr> *vnodes);
FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
} // namespace lite
@@ -67,6 +67,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
int status = modelImporter->Import(flag);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
graph = modelImporter->GetResult();
graph->set_attr("graph_name", MakeValue("main_graph"));
} else {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
@@ -90,6 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
}

// transform
transform->SetGraphDef(meta_graph);
auto status = transform->Transform(*flag);
@@ -16,6 +16,7 @@

#include "tools/converter/graphdef_transform.h"
#include <string>
#include <algorithm>
#include "schema/model_generated.h"
#include "src/common/log_adapter.h"
#include "tools/converter/converter_flags.h"
@@ -37,9 +38,21 @@
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"

using std::string;
namespace mindspore::lite {

std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
std::vector<schema::CNodeT *> old_nodes{};
old_nodes.resize(graphDefT->nodes.size());
std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(),
[](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
return old_nodes;
}

GraphDefTransform::GraphDefTransform() = default;

GraphDefTransform::~GraphDefTransform() = default;
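
GetGraphNodes snapshots the raw node pointers before a pass-manager run; SubgraphNodePass later diffs that snapshot against the post-run node list to find deleted and newly created nodes. A self-contained sketch of the snapshot-and-diff idea (stand-in types):

    #include <algorithm>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct CNodeT { std::string name; };  // hypothetical stand-in

    // Snapshot the current node pointers, as GetGraphNodes() does.
    std::vector<CNodeT *> Snapshot(const std::vector<std::unique_ptr<CNodeT>> &nodes) {
      std::vector<CNodeT *> out(nodes.size());
      std::transform(nodes.begin(), nodes.end(), out.begin(),
                     [](const std::unique_ptr<CNodeT> &n) { return n.get(); });
      return out;
    }

    int main() {
      std::vector<std::unique_ptr<CNodeT>> nodes;
      nodes.push_back(std::make_unique<CNodeT>(CNodeT{"a"}));
      auto before = Snapshot(nodes);
      nodes.push_back(std::make_unique<CNodeT>(CNodeT{"b"}));  // a pass adds a node
      auto after = Snapshot(nodes);
      // Anything in `after` but not in `before` is new and needs a subgraph index.
      for (auto *n : after)
        if (std::find(before.begin(), before.end(), n) == before.end())
          std::cout << n->name << " is new\n";  // b is new
    }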
@@ -48,141 +61,232 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _

int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}

// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}

// postconvert pass
{
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
if (ctx.fmk != converter::FmkType_TF) {
{
auto old_nodes = GetGraphNodes();
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
}
// format transform
{
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;

// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());

// postconvert pass
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
return status;
}
}

// format transform
{
// init old node indices
auto old_nodes = GetGraphNodes();

Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}

{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;

{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
return status;
}
}

{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
}

// do quantization
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensorQuantOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}

// insert quantNode and deQuantNode
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
}
}

// switch pass
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer switchOptimizer;
switchOptimizer.AddPass(new (std::nothrow) SwitchPass());
switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = switchOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run switch graphPasses Failed";
return status;
}
}

// do quantization
// subgraph tensor pass
{
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
status = tensorQuantOptimizer.Run(graphDefT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}

// insert quantNode and deQuantNode
{
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
status = quantNodeOptimizer.Run(graphDefT);
Optimizer subgraphTensorOptimizer;
subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = subgraphTensorOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
return status;
}
}

// tensor name
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer nameOptimizer;
nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
status = nameOptimizer.Run(graphDefT);
@@ -192,16 +296,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}

// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
return RET_OK;
}
} // namespace mindspore::lite
@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H

#include <memory>
#include <vector>
#include "tools/converter/optimizer.h"
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
@@ -39,6 +40,7 @@ class GraphDefTransform {
inline schema::MetaGraphT *GetOutput() { return graphDefT; }

protected:
std::vector<schema::CNodeT *> GetGraphNodes();
schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr;
};
@@ -15,6 +15,9 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
)
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})
@@ -0,0 +1,76 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {

void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
  for (auto &subgraph : graph->subGraph) {
    for (auto &idx : subgraph->nodeIndices) {
      if (idx > node_idx) {
        idx--;
      }
    }
  }
}

STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  std::vector<schema::CNodeT *> new_nodes{};
  std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
                 [](std::unique_ptr<CNodeT> &node) { return node.get(); });

  for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
    if (!IsContain(new_nodes, *it)) {
      size_t node_idx = it - old_nodes_.begin();
      for (auto &subgraph : graph->subGraph) {
        auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
        if (node_idx_pos != subgraph->nodeIndices.end()) {
          subgraph->nodeIndices.erase(node_idx_pos);
          UpdateSubgraphNodeIndices(node_idx, graph);
          break;
        }
      }
      it = old_nodes_.erase(it);
    } else {
      it++;
    }
  }

  for (uint32_t i = 0; i < new_nodes.size(); i++) {
    if (!IsContain(old_nodes_, new_nodes[i])) {
      for (auto &subgraph : graph->subGraph) {
        if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) {
          subgraph->nodeIndices.push_back(old_nodes_.size());
          old_nodes_.push_back(new_nodes[i]);
        }
      }
    }
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
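
The new SubgraphNodePass reconciles each subGraph's nodeIndices after earlier passes mutate the node list: vanished nodes are erased from their subgraph and all later indices shifted down, while newly created nodes are attached to whichever subgraph contains a positional neighbour (index i - 1 or i + 1). A runnable sketch of the shift step, on simplified containers rather than the schema types:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirror of UpdateSubgraphNodeIndices: removing the node at node_idx from the
    // flat node array shifts every later node index in every subgraph down by one.
    void UpdateSubgraphNodeIndices(size_t node_idx, std::vector<std::vector<uint32_t>> *subgraphs) {
      for (auto &indices : *subgraphs) {
        for (auto &idx : indices) {
          if (idx > node_idx) --idx;
        }
      }
    }

    int main() {
      std::vector<std::vector<uint32_t>> subgraphs{{0, 1}, {2, 3}};
      // Node 1 was deleted by an optimizer pass; its own entry is already erased:
      subgraphs[0] = {0};
      UpdateSubgraphNodeIndices(1, &subgraphs);
      for (auto idx : subgraphs[1]) std::cout << idx << " ";  // 1 2
    }

The neighbour heuristic for new nodes relies on passes inserting nodes adjacent to existing ones; it is a positional guess rather than a dataflow analysis.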
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H

#include <vector>
#include <utility>
#include "tools/converter/optimizer.h"

namespace mindspore {
namespace lite {
class SubgraphNodePass : public GraphPass {
 public:
  explicit SubgraphNodePass(std::vector<schema::CNodeT *> old_nodes) : old_nodes_(std::move(old_nodes)) {}

  ~SubgraphNodePass() override = default;

  STATUS Run(schema::MetaGraphT *graph) override;

 private:
  void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
  std::vector<schema::CNodeT *> old_nodes_;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
@@ -0,0 +1,100 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {

bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
  for (const auto &node : graph->nodes) {
    if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) {
      return true;
    }
    if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) {
      return true;
    }
  }
  return false;
}

STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
  for (const auto &subgraph : graph->subGraph) {
    UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx);
    UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx);
  }
  for (const auto &node : graph->nodes) {
    UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx);
    UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx);
  }
  UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx);
  UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx);
  return RET_OK;
}

STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
  for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) {
    uint32_t idx = it - graph->allTensors.begin();
    if (IsUsing(graph, idx)) {
      it++;
    } else {
      it = graph->allTensors.erase(it);
      UpdateTensorIdx(graph, idx);
    }
  }
  return RET_OK;
}

STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
  MS_ASSERT(graph->subGraph.size() > 0);
  graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
  graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end());
  return RET_OK;
}

STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);

  int ret = RemoveUselessTensors(graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret;
    return ret;
  }

  ret = SetSubgraphTensorIndices(graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
    return ret;
  }

  ret = SyncMainGraphInputAndOutput(graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "SyncMainGraphInputAndOutput failed, ret: " << ret;
    return ret;
  }

  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
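
RemoveUselessTensors above erases tensors that no node references and renumbers every surviving reference after each erase. A compact runnable model of that erase-and-renumber loop, using strings in place of tensor objects:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Sketch of RemoveUselessTensors: drop unreferenced tensors and renumber the
    // surviving references, one erase at a time (indices > erased shift down).
    int main() {
      std::vector<std::string> tensors{"w0", "dead", "x", "y"};
      std::vector<uint32_t> refs{0, 2, 3};  // indices the graph still uses
      for (auto it = tensors.begin(); it != tensors.end();) {
        uint32_t idx = static_cast<uint32_t>(it - tensors.begin());
        if (std::find(refs.begin(), refs.end(), idx) != refs.end()) {
          ++it;
        } else {
          it = tensors.erase(it);
          for (auto &r : refs)
            if (r > idx) --r;
        }
      }
      for (auto r : refs) std::cout << tensors[r] << " ";  // w0 x y
    }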
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
|
||||
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "tools/converter/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class SubgraphTensorPass : public GraphPass {
|
||||
public:
|
||||
SubgraphTensorPass() = default;
|
||||
|
||||
~SubgraphTensorPass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
STATUS RemoveUselessTensors(schema::MetaGraphT *graph);
|
||||
bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
|
||||
STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
|
||||
STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph);
|
||||
|
||||
template <typename T>
|
||||
void UpdateVec(std::vector<T> *vec, T element) {
|
||||
for (auto iter = vec->begin(); iter != vec->end(); iter++) {
|
||||
if (*iter > element) {
|
||||
(*iter)--;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H

@ -16,6 +16,7 @@

#include <vector>
#include <map>
#include <algorithm>
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
@ -96,38 +97,6 @@ std::unique_ptr<schema::TensorT> SingleSwitchPass::NewTensor(const std::unique_p
  return out_tensor;
}

STATUS SingleSwitchPass::MoveMaxIterationToCond() {
  auto &body_subgraph_input = graph_->subGraph.at(body_subgraph_index_)->inputIndices;
  for (auto it = body_subgraph_input.begin(); it != body_subgraph_input.end();) {
    if (!body_to_cond_partial_node_->inputIndex.empty() && IsContain(body_to_cond_partial_node_->inputIndex, *it)) {
      int32_t max_iteration_idx = it - body_subgraph_input.begin();
      // get the maxIteration tensor
      auto &max_iteration_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex.at(max_iteration_idx));
      auto all_tensor_idx = std::find(graph_->allTensors.begin(), graph_->allTensors.end(), max_iteration_tensor) -
                            graph_->allTensors.begin();

      // remove maxIteration from the body_to_cond partial node
      body_to_cond_partial_node_->inputIndex.erase(body_to_cond_partial_node_->inputIndex.begin() + max_iteration_idx);

      // redirect cond subgraph references from the body tensor to the maxIteration tensor in allTensors
      auto body_max_iteration_tensor_idx = body_subgraph_input.at(max_iteration_idx);
      for (auto &node : cond_graph_nodes_) {
        std::replace_if(
          node->inputIndex.begin(), node->inputIndex.end(),
          [&body_max_iteration_tensor_idx](uint32_t idx) { return idx == body_max_iteration_tensor_idx; },
          all_tensor_idx);
      }

      // remove maxIteration from the body partial input and the body func input
      body_partial_node_->inputIndex.erase(body_partial_node_->inputIndex.begin() + max_iteration_idx);
      it = body_subgraph_input.erase(it);
    } else {
      it++;
    }
  }
  return RET_OK;
}
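
// MoveMaxIterationToCond hoists the loop's maxIteration operand out of the
// body subgraph: the cond graph keeps reading it from the main graph's
// tensor, while the body partial, the body-to-cond partial and the body
// inputs all drop it.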

STATUS SingleSwitchPass::InsertMerge() {
  int ret = RET_OK;
  auto merge_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);

@ -154,9 +123,9 @@ STATUS SingleSwitchPass::InsertMerge() {
  }

  // double the merge inputs so they also cover the outputs of the body node
  for (auto &out_index : origin_switch_output_tensor_indices_) {
    auto &switch_out_tensor = graph_->allTensors.at(out_index);
    auto tensor = NewTensor(switch_out_tensor);
  for (auto &index : cond_partial_node_->inputIndex) {
    auto &in_tensor = graph_->allTensors.at(index);
    auto tensor = NewTensor(in_tensor);
    graph_->allTensors.push_back(std::move(tensor));
    merge_node->inputIndex.push_back(graph_->allTensors.size() - 1);
  }
@ -266,10 +235,6 @@ STATUS SingleSwitchPass::Init() {
    return RET_NULL_PTR;
  }

  if (switch_node_->inputIndex.size() == kSwitchMinInputSize) {
    return RET_OK;
  }

  if (switch_node_->inputIndex.size() < kSwitchMinInputSize) {
    MS_LOG(ERROR) << "switch node: " << switch_node_->name
                  << " 's input size is not right, size: " << switch_node_->inputIndex.size();
@ -297,10 +262,6 @@ STATUS SingleSwitchPass::Init() {
    }
  }

  if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial ||
      body_partial_node_->primitive->value.type != PrimitiveType_Partial) {
    return RET_OK;
  }
  // get cond_graph_nodes_
  cond_subgraph_index_ = cond_partial_node_->primitive->value.AsPartial()->subGraphIndex;
  auto cond_node_indices = graph_->subGraph.at(cond_subgraph_index_)->nodeIndices;
@ -330,17 +291,36 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
    return RET_INPUT_PARAM_INVALID;
  }
  auto &partial_inputs = partial_node->inputIndex;
  auto &subgraph_inputs = graph_->subGraph.at(subgraph_index)->inputIndices;
  auto &subgraph = graph_->subGraph.at(subgraph_index);
  auto &subgraph_inputs = subgraph->inputIndices;

  std::map<int, int> subgraph_input_map;
  std::vector<int> new_subgraph_inputs{};
  std::vector<std::pair<int, int>> tmp_inputs_order{};
  for (unsigned int &subgraph_input : subgraph_inputs) {
    auto &tensor = graph_->allTensors.at(subgraph_input);
    // get the parameter input index k from the name: subgraph name + "_input_" + k
    char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 7];
    int partial_idx = k - '0';
    if (tensor->name.size() < subgraph->name.size() + 8) {
      MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right.";
      return RET_ERROR;
    }
    int partial_idx = -1;
    if (tensor->name.find("_input_") != std::string::npos) {
      // get the parameter input index k from the name: subgraph name + "_input_" + k
      auto pos = subgraph->name.size() + sizeof("_input_");
      auto pos2 = tensor->name.find('_', pos);
      auto idx_str = tensor->name.substr(pos - 1, pos2);
      partial_idx = std::stoi(idx_str);
    }

    if (tensor->name.find("_output_") != std::string::npos) {
      // get the parameter input index k from the name: subgraph name + "_output_" + k
      auto pos = subgraph->name.size() + sizeof("_output_");
      auto pos2 = tensor->name.find('_', pos);
      auto idx_str = tensor->name.substr(pos - 1, pos2);
      partial_idx = std::stoi(idx_str);
    }

    subgraph_input_map.insert(std::pair<int, int>{subgraph_input, partial_inputs[partial_idx]});
    new_subgraph_inputs.push_back(partial_inputs[partial_idx]);
    tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]);
  }

  for (auto &subgraph_node : subgraph_nodes) {

@ -350,6 +330,13 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
      }
    }
  }

  std::sort(tmp_inputs_order.begin(), tmp_inputs_order.end(),
            [](std::pair<int, int> a, std::pair<int, int> b) { return a.first < b.first; });

  std::vector<int> new_subgraph_inputs{};
  std::transform(tmp_inputs_order.begin(), tmp_inputs_order.end(), std::back_inserter(new_subgraph_inputs),
                 [](std::pair<int, int> iter) { return iter.second; });
  subgraph_inputs.assign(new_subgraph_inputs.begin(), new_subgraph_inputs.end());

  return RET_OK;
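
// Subgraph parameter tensors are matched back to partial-node operands purely
// by name: the converter hardcodes names of the form
// <subgraph_name>_input_<k>_parameter / <subgraph_name>_output_<k>_cnode, so
// the index k recovered above fixes the operand order. Renaming those tensors
// elsewhere would silently break this mapping.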

@ -362,17 +349,28 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
    return RET_INPUT_PARAM_INVALID;
  }
  auto &partial_outputs = partial_node->outputIndex;
  auto &subgraph_outputs = graph_->subGraph.at(subgraph_index)->outputIndices;
  auto &subgraph = graph_->subGraph.at(subgraph_index);
  auto &subgraph_outputs = subgraph->outputIndices;

  std::map<int, int> subgraph_output_map;
  std::vector<int> new_subgraph_outputs{};
  std::vector<std::pair<int, int>> tmp_outputs_order{};
  for (unsigned int &subgraph_output : subgraph_outputs) {
    auto &tensor = graph_->allTensors.at(subgraph_output);
    // get the parameter output index k from the name: subgraph name + "_output_" + k
    char k = tensor->name[graph_->subGraph.at(subgraph_index)->name.size() + 8];
    int partial_idx = k - '0';
    subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]});
    new_subgraph_outputs.push_back(partial_outputs[partial_idx]);
    for (auto &node : subgraph_nodes) {
      if (IsContain(node->outputIndex, subgraph_output)) {
        int partial_idx = -1;
        if (node->name == "LogicalAnd") {
          partial_idx = 0;
        } else {
          // get the parameter output index k from the name: subgraph name + "_output_" + k
          auto pos = subgraph->name.size() + sizeof("_output_");
          auto pos2 = node->name.find('_', pos);
          auto idx_str = node->name.substr(pos - 1, pos2);
          partial_idx = std::stoi(idx_str);
        }
        subgraph_output_map.insert(std::pair<int, int>{subgraph_output, partial_outputs[partial_idx]});
        tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]);
      }
    }
  }

  for (auto &subgraph_node : subgraph_nodes) {

@ -382,6 +380,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
      }
    }
  }

  std::vector<int> new_subgraph_outputs{};
  std::transform(tmp_outputs_order.begin(), tmp_outputs_order.end(), std::back_inserter(new_subgraph_outputs),
                 [](std::pair<int, int> iter) { return iter.second; });
  subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end());

  return RET_OK;

@ -416,102 +418,6 @@ STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() {
  return ret;
}

STATUS SingleSwitchPass::ConvertSwitchToSelect() {
  MS_ASSERT(switch_node_->inputIndex.size() >= 3);
  MS_ASSERT(switch_node_->inputIndex.size() % 2 != 0);
  MS_ASSERT(switch_node_->outputIndex.size() * 2 + 1 == switch_node_->inputIndex.size());
  auto bool_index = switch_node_->inputIndex.front();

  // insert switch node1
  auto switch_node1 = std::make_unique<CNodeT>();
  switch_node1->name = switch_node_->name + "-Switch-1";
  switch_node1->primitive = std::make_unique<PrimitiveT>();
  switch_node1->primitive->value.type = PrimitiveType_Switch;
  switch_node1->primitive->value.value = new (std::nothrow) SwitchT();
  switch_node1->inputIndex = {bool_index};
  std::vector<int> part_one_input_index(
    switch_node_->inputIndex.begin() + 1,
    switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2);
  switch_node1->inputIndex.insert(switch_node1->inputIndex.end(), part_one_input_index.begin(),
                                  part_one_input_index.end());
  std::vector<std::unique_ptr<TensorT>> switch_output_tensors1(part_one_input_index.size() * 2);
  std::vector<int> switch_output_indexes1(part_one_input_index.size() * 2);
  int i = 0;
  for (const auto &input_index : part_one_input_index) {
    auto &switch_in_tensor = graph_->allTensors.at(input_index);
    auto tensor1 = NewTensor(switch_in_tensor);
    auto tensor2 = NewTensor(switch_in_tensor);
    switch_output_tensors1[i] = std::move(tensor1);
    switch_output_tensors1[part_one_input_index.size() + i] = std::move(tensor2);
    switch_output_indexes1[i] = graph_->allTensors.size() - 1 + i;
    switch_output_indexes1[part_one_input_index.size() + i] =
      graph_->allTensors.size() - 1 + i + part_one_input_index.size();
    i++;
  }
  for (auto &tensor : switch_output_tensors1) {
    graph_->allTensors.emplace_back(std::move(tensor));
  }
  switch_node1->outputIndex.insert(switch_node1->outputIndex.begin(), switch_output_indexes1.begin(),
                                   switch_output_indexes1.end());

  // insert switch node2
  auto switch_node2 = std::make_unique<CNodeT>();
  switch_node2->name = switch_node_->name + "-Switch-2";
  switch_node2->primitive = std::make_unique<PrimitiveT>();
  switch_node2->primitive->value.type = PrimitiveType_Switch;
  switch_node2->primitive->value.value = new (std::nothrow) SwitchT();
  switch_node2->inputIndex = {bool_index};

  std::vector<int> part_two_input_index(
    switch_node_->inputIndex.begin() + 1 + (switch_node_->inputIndex.size() - 1) / 2, switch_node_->inputIndex.end());
  switch_node2->inputIndex.insert(switch_node2->inputIndex.end(), part_two_input_index.begin(),
                                  part_two_input_index.end());
  std::vector<std::unique_ptr<TensorT>> switch_output_tensors2(part_two_input_index.size() * 2);
  std::vector<int> switch_output_indexes2(part_two_input_index.size() * 2);
  i = 0;
  for (const auto &input_index : part_two_input_index) {
    auto &switch_in_tensor = graph_->allTensors.at(input_index);
    auto tensor1 = NewTensor(switch_in_tensor);
    auto tensor2 = NewTensor(switch_in_tensor);
    switch_output_tensors2[i] = std::move(tensor1);
    switch_output_tensors2[part_two_input_index.size() + i] = std::move(tensor2);
    switch_output_indexes2[i] = graph_->allTensors.size() - 1 + i;
    switch_output_indexes2[part_two_input_index.size() + i] =
      graph_->allTensors.size() - 1 + i + part_two_input_index.size();
    i++;
  }
  for (auto &tensor : switch_output_tensors2) {
    graph_->allTensors.emplace_back(std::move(tensor));
  }
  switch_node2->outputIndex.insert(switch_node2->outputIndex.begin(), switch_output_indexes2.begin(),
                                   switch_output_indexes2.end());

  // insert merge
  auto merge_node = std::make_unique<CNodeT>();
  merge_node->name = switch_node_->name + "-Merge";
  merge_node->primitive = std::make_unique<PrimitiveT>();
  merge_node->primitive->value.type = PrimitiveType_Merge;
  merge_node->primitive->value.value = new (std::nothrow) MergeT();

  std::vector<int> merge_input_indexes(switch_node_->outputIndex.size() * 2);
  for (i = 0; i < switch_node_->outputIndex.size(); i++) {
    merge_input_indexes[i] = switch_output_indexes1[i];
    merge_input_indexes[i + switch_node_->outputIndex.size()] =
      switch_output_indexes2[i + switch_node_->outputIndex.size()];
    merge_node->outputIndex.emplace_back(switch_node_->outputIndex.at(i));
  }
  merge_node->inputIndex.insert(merge_node->inputIndex.end(), merge_input_indexes.begin(), merge_input_indexes.end());
  graph_->nodes.emplace_back(std::move(switch_node1));
  graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);
  graph_->nodes.emplace_back(std::move(switch_node2));
  graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);
  graph_->nodes.emplace_back(std::move(merge_node));
  graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1);

  RemoveUselessNode(switch_node_, graph_);
  return RET_OK;
}
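
// ConvertSwitchToSelect handles a Switch whose branches are plain tensors
// rather than Partial subgraphs: its (2n + 1) inputs (bool, x_1..x_n,
// y_1..y_n) become two Switch nodes gated by the same bool plus a Merge that
// selects between the duplicated outputs.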

STATUS SingleSwitchPass::Run() {
  int ret = Init();
  if (ret != RET_OK) {

@ -519,24 +425,6 @@ STATUS SingleSwitchPass::Run() {
    return ret;
  }

  if (switch_node_->inputIndex.size() == kSwitchMinInputSize) {
    return RET_OK;
  }

  if (cond_partial_node_->primitive->value.type != PrimitiveType_Partial ||
      body_partial_node_->primitive->value.type != PrimitiveType_Partial) {
    ret = ConvertSwitchToSelect();
    return ret;
  }

  if (IsLoop()) {
    ret = MoveMaxIterationToCond();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "MoveMaxIterationToCond failed, ret: " << ret;
      return ret;
    }
  }

  ret = DoubleSwitchOutput();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoubleSwitchOutput failed, ret: " << ret;

@ -45,11 +45,9 @@ class SingleSwitchPass {
  STATUS Init();
  size_t InitThisGraphIndex();
  STATUS DoubleSwitchOutput();
  STATUS MoveMaxIterationToCond();
  STATUS UpdateSwitchUser();
  STATUS ConcatCondSubgraphInputAndOutput();
  STATUS ConcatBodySubgraphInputAndOutput();
  STATUS ConvertSwitchToSelect();
  bool IsLoop();
  STATUS InsertMerge();
  STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node,

@ -27,56 +27,71 @@ namespace mindspore {
namespace lite {
STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  std::vector<std::unique_ptr<schema::CNodeT>> newNodes;
  std::vector<size_t> sinkedTensorIdxes;
  // put all const tensor indices into sinkedTensorIdxes
  std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
  std::vector<size_t> sinked_tensor_idxes;
  // put all const tensor indices into sinked_tensor_idxes
  for (size_t i = 0; i < graph->allTensors.size(); i++) {
    if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) {
      sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i);
      sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i);
    }
  }
  auto &oldNodes = graph->nodes;
  std::queue<std::unique_ptr<schema::CNodeT>> opQueue;
  // put all non-depend nodes into the queue
  for (auto &node : graph->nodes) {
    if (IsNodeNonDepend(node, sinkedTensorIdxes)) {
      sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end());
      opQueue.push(std::move(node));
    }
  }
  // bfs
  while (!opQueue.empty()) {
    auto &node = opQueue.front();
    auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get()));
    for (auto postNodeIdx : postNodeIdxes) {
      auto &postNode = oldNodes.at(postNodeIdx);
      // check if postNode is non-depended
      if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) {
        sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end());
        opQueue.push(std::move(postNode));
  auto &old_nodes = graph->nodes;
  std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
  // put all non-depend nodes into the queue
  for (size_t i = 0; i < graph->subGraph.size(); i++) {
    std::vector<unsigned int> new_subgraph_node_indices = {};
    auto subgraph_node_indices = graph->subGraph[i]->nodeIndices;

    for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
      auto &node = old_nodes[subgraph_node_indices[j]];
      if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
        sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
        op_queue.push(std::move(node));
      }
    }
    newNodes.emplace_back(std::move(node));
    opQueue.pop();
    while (!op_queue.empty()) {
      auto &node = op_queue.front();
      auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get()));
      for (auto post_node_idx : post_node_idxes) {
        if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
          auto &post_node = old_nodes.at(post_node_idx);
          // check if post_node is non-depended
          if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
            sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(),
                                       post_node->outputIndex.end());
            op_queue.push(std::move(post_node));
          }
        }
      }
      new_nodes.emplace_back(std::move(node));
      new_subgraph_node_indices.push_back(new_nodes.size() - 1);
      op_queue.pop();
    }
    graph->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
  }
  if (newNodes.size() != oldNodes.size()) {
    MS_LOG(ERROR) << "Unknown error in TopologicalSort, oldNodesSize: " << oldNodes.size()
                  << ", newNodesSize: " << newNodes.size();
  if (new_nodes.size() != old_nodes.size()) {
    MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
                  << ", new_nodes size: " << new_nodes.size();
    return RET_ERROR;
  }
  graph->nodes.swap(newNodes);
  graph->nodes.swap(new_nodes);
  return RET_OK;
}
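
// The sort is effectively Kahn's algorithm run once per subgraph: nodes whose
// inputs are all "sinked" (constants or outputs of already-emitted nodes)
// enter the queue, and each subGraph's nodeIndices are rewritten to the new
// emission order.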

bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
                                          const std::vector<size_t> &sinkedTensorIdxes) {
                                          const std::vector<size_t> &sinked_tensor_idxes) {
  MS_ASSERT(node != nullptr);
  for (auto inputIdx : node->inputIndex) {
    if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) {
      return false;
    }
  if (node->primitive->value.type == schema::PrimitiveType_Merge) {
    auto node_input_index = node->inputIndex;
    MS_ASSERT(node_input_index.size() % 2 == 0);
    return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2,
                       [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }) ||
           std::all_of(node_input_index.begin() + node_input_index.size() / 2, node_input_index.end(),
                       [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); });
  } else {
    return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                       [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
  }
  return true;
}
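
// Merge is special-cased because its inputs arrive in two halves, one per
// alternative branch: the node is schedulable as soon as either half is fully
// available, since only one branch produces data in a given execution.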
}  // namespace lite
}  // namespace mindspore

@ -54,6 +54,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
  return func_graph_ptr_;
}

@ -80,6 +80,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
    MS_LOG(ERROR) << "convert graph outputs failed.";
    return nullptr;
  }
  func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
  return func_graph_ptr_;
}

@ -61,7 +61,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
  }
  primitive->value.type = schema::PrimitiveType_Mul;
  primitive->value.value = attr.release();
} else if (tf_op.op() == "Div") {
} else if (tf_op.op() == "Div" || tf_op.op() == "RealDiv") {
  auto attr = std::make_unique<schema::DivT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";

@ -154,6 +154,7 @@ TFNodeRegistrar g_tfAddV2Parser("AddV2", new TFArithmeticParser());
TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser());
TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser());
TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser());
TFNodeRegistrar g_tfRealDivParser("RealDiv", new TFArithmeticParser());
TFNodeRegistrar g_tfMaximumParser("Maximum", new TFArithmeticParser());
TFNodeRegistrar g_tfMinimumParser("Minimum", new TFArithmeticParser());
TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser());
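
// TF's RealDiv performs elementwise real-valued division; for the float
// tensors this converter targets it coincides with Div, so both ops share the
// schema Div attribute and the same parser class.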

@ -37,10 +37,11 @@ static const std::vector<schema::PrimitiveType> tensorListOutputOpList = {

AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
  AnfNodePtr ret = nullptr;
  if (anf_node_map.find(name) != anf_node_map.end()) {
    ret = anf_node_map.at(name);
  auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name);
  if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) {
    ret = anf_node_map.at(flat_anf_name);
  } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) {
    ret = anf_node_map.at(name + ":0");
    ret = anf_node_map.at(flat_anf_name + ":0");
  }
  return ret;
}
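
// TF node references may carry an output-slot suffix ("node:0"); resolving
// through GetFlattenNodeName first lets both the bare name and the
// ":0"-suffixed form map to the same ANF node.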

@ -212,6 +213,17 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
    if (status != RET_OK) {
      return status;
    }
  } else if (type == kObjectTypeString) {
    auto tensor_data = new (std::nothrow) string;
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << "new string failed";
      return RET_NULL_PTR;
    }
    if (tensor_proto.string_val_size() == 1) {
      string value = tensor_proto.string_val(0);
      *tensor_data = value;
    } else {
      MS_LOG(ERROR) << "string size bigger than one is not supported.";
      delete tensor_data;
      return RET_ERROR;
    }
    tensor_size = (*tensor_data).size();
    param_value->SetTensorData(tensor_data, tensor_size);
  } else {
    MS_LOG(ERROR) << "Unsupported dataType: " << type;
    return RET_ERROR;

@ -318,6 +330,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    return nullptr;
  }
  anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));

  for (int i = 0; i < tf_root_graph_->node_size(); i++) {
    auto &node_def = tf_root_graph_->node(i);

@ -364,7 +377,6 @@ STATUS TFModelParser::ConvertSubgraph() {
  std::map<CNodePtr, FuncGraphPtr> while_cond_map;
  std::map<CNodePtr, FuncGraphPtr> while_body_map;
  for (int i = 0; i < subgraph_size; i++) {
    std::vector<ParameterPtr> sub_graph_inputs;
    auto &tf_sub_fuction = graph_def_liarary.function(i);
    auto &tf_sub_signature = tf_sub_fuction.signature();
    auto input_arg_size = tf_sub_signature.input_arg_size();

@ -381,13 +393,17 @@ STATUS TFModelParser::ConvertSubgraph() {
    }

    FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
    sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
    std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
    // convert sub graph inputs
    std::vector<ParameterPtr> sub_graph_inputs;
    for (int j = 0; j < input_arg_size; j++) {
      auto &input_arg = tf_sub_signature.input_arg(j);
      auto paramter = sub_func_graph->add_parameter();
      paramter->set_name(input_arg.name());
      anf_sub_node_map[input_arg.name()] = paramter;
      auto root_while_inputs = while_cnode->inputs();
      paramter->set_abstract(root_while_inputs[j + 1]->abstract());
      sub_graph_inputs.emplace_back(paramter);
    }
    std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;

@ -452,8 +468,19 @@ STATUS TFModelParser::ConvertSubgraph() {
    }
    // hardcode the subgraph input names
    for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
      sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter");
      sub_graph_inputs[j]->set_name(sub_graph_name + "_input_" + std::to_string(j) + "_parameter");
    }
    // hardcode the subgraph output names
    for (size_t j = 1; j < sub_output_nodes.size(); j++) {
      if (utils::isa<CNodePtr>(sub_output_nodes[j])) {
        sub_output_nodes[j]->cast<CNodePtr>()->set_fullname_with_scope(sub_graph_name + "_output_" +
                                                                       std::to_string(j - 1) + "_cnode");
      } else if (utils::isa<ParameterPtr>(sub_output_nodes[j])) {
        sub_output_nodes[j]->cast<ParameterPtr>()->set_name(sub_graph_name + "_output_" + std::to_string(j - 1) +
                                                            "_parameter");
      }
    }

    MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
  }
  auto status = WhileNodePostProcess(while_cond_map, while_body_map);

@ -469,9 +496,8 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr
    MS_LOG(ERROR) << "while cond body size error";
    return RET_ERROR;
  }
  std::vector<FuncGraphPtr> roots = {anf_root_graph_};
  auto root_func_manager = std::make_shared<FuncGraphManager>(roots);
  anf_root_graph_->set_manager(root_func_manager);
  static auto root_func_manager = Manage(anf_root_graph_);

  for (auto &kv : while_cond_map) {
    auto while_node = kv.first;
    auto &cond_sub_graph = kv.second;

@ -633,6 +659,11 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
  for (auto &pair : tf_root_graph_nodes_) {
    for (int i = 0; i < pair.second->input_size(); ++i) {
      all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i)));
      auto input_name = pair.second->input(i);
      if (input_name[0] == '^') {
        input_name.erase(0, 1);
      }
      all_node_inputs.insert(input_name);
    }
  }
  for (auto &pair : tf_root_graph_nodes_) {
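
// In TF GraphDefs an input name prefixed with '^' denotes a control
// dependency rather than a data edge; stripping the prefix ensures such
// consumers still mark the producer as used when collecting graph outputs.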

@ -644,7 +675,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
  auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
  auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
  if (anf_node == nullptr) {
    MS_LOG(ERROR) << "can't find anf node";
    MS_LOG(ERROR) << "can't find anf node: " << origin_name;
    return RET_ERROR;
  }
  output_nodes.push_back(anf_node);

@ -0,0 +1,78 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_ragged_range_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFRaggedRangeParser::Parse(const tensorflow::NodeDef &tf_op,
                                  const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                  PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF RaggedRangeParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::RangeT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }

  tensorflow::AttrValue attr_value;
  if (!TensorFlowUtils::FindAttrValue(tf_op, "starts", &attr_value)) {
    MS_LOG(ERROR) << "The starts attr should be specified";
    return RET_ERROR;
  }
  attr->start = static_cast<int32_t>(attr_value.i());

  if (!TensorFlowUtils::FindAttrValue(tf_op, "limits", &attr_value)) {
    MS_LOG(ERROR) << "The limits attr should be specified";
    return RET_ERROR;
  }
  attr->limit = static_cast<int32_t>(attr_value.i());

  if (!TensorFlowUtils::FindAttrValue(tf_op, "deltas", &attr_value)) {
    MS_LOG(ERROR) << "The deltas attr should be specified";
    return RET_ERROR;
  }
  attr->delta = static_cast<int32_t>(attr_value.i());

  primitive->value.type = schema::PrimitiveType_Range;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  auto status = AddOpInput(tf_op, 0, inputs);
  return status;
}
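// RaggedRange is lowered onto the plain Range schema using only the scalar
// starts/limits/deltas attributes; truly ragged (per-row) ranges are not
// representable this way, so this mapping assumes scalar arguments.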
TFNodeRegistrar g_tfRaggedRangeParser("RaggedRange", new TFRaggedRangeParser());
}  // namespace lite
}  // namespace mindspore

@ -0,0 +1,36 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFRaggedRangeParser : public TFNodeParser {
 public:
  TFRaggedRangeParser() = default;
  ~TFRaggedRangeParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RAGGED_RANGE_PARSER_H_

@ -0,0 +1,78 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_range_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFRangeParser::Parse(const tensorflow::NodeDef &tf_op,
                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                            std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF RangeParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::RangeT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }

  tensorflow::AttrValue attr_value;
  if (!TensorFlowUtils::FindAttrValue(tf_op, "start", &attr_value)) {
    MS_LOG(ERROR) << "The start attr should be specified";
    return RET_ERROR;
  }
  attr->start = static_cast<int32_t>(attr_value.i());

  if (!TensorFlowUtils::FindAttrValue(tf_op, "limit", &attr_value)) {
    MS_LOG(ERROR) << "The limit attr should be specified";
    return RET_ERROR;
  }
  attr->limit = static_cast<int32_t>(attr_value.i());

  if (!TensorFlowUtils::FindAttrValue(tf_op, "delta", &attr_value)) {
    MS_LOG(ERROR) << "The delta attr should be specified";
    return RET_ERROR;
  }
  attr->delta = static_cast<int32_t>(attr_value.i());

  primitive->value.type = schema::PrimitiveType_Range;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  auto status = AddOpInput(tf_op, 0, inputs);
  return status;
}
TFNodeRegistrar g_tfRangeParser("Range", new TFRangeParser());
}  // namespace lite
}  // namespace mindspore

@ -0,0 +1,36 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFRangeParser : public TFNodeParser {
 public:
  TFRangeParser() = default;
  ~TFRangeParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANGE_PARSER_H_

@ -9,7 +9,7 @@
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WRRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@ -18,8 +18,8 @@
#include <string>
#include <vector>
#include <string_view>
#include <unordered_map>
#include <regex>
#include <unordered_map>
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"

@ -76,6 +76,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  func_graph_->set_attr("graph_name", MakeValue("main_graph"));
  return func_graph_;
}

@ -21,7 +21,22 @@
#include "mindspore/lite/src/ops/primitive_c.h"
#include "tools/anf_importer/import_from_meta_graphT.h"

using mindspore::lite::RET_INFER_INVALID;

namespace mindspore::opt {

ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
  auto para_value_lite = std::make_shared<ParamValueLite>();
  if (para_value_lite == nullptr) {
    MS_LOG(ERROR) << "new ParamValueLite failed";
    return nullptr;
  }
  para_value_lite->set_tensor_shape(tensor->shape());
  para_value_lite->set_tensor_type(tensor->data_type());
  para_value_lite->set_format(tensor->format());
  return para_value_lite;
}

abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
  MS_ASSERT(nullptr != tensor);
  std::vector<int> shape(tensor->shape());

@ -33,15 +48,30 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
    MS_LOG(ERROR) << "new AbstractTensor failed";
    return nullptr;
  }
  auto new_value = std::make_shared<ParamValueLite>();
  if (new_value == nullptr) {

  auto para_value_lite = NewParamValueLitePtr(tensor);
  if (para_value_lite == nullptr) {
    MS_LOG(ERROR) << "new ParamValueLite failed";
    return nullptr;
  }
  new_value->set_tensor_shape(tensor->shape());
  new_value->set_tensor_type(tensor->data_type());
  new_value->set_format(tensor->format());
  new_abstract->set_value(new_value);

  if (type_id == kObjectTypeTensorType) {
    auto tensor_list = dynamic_cast<lite::TensorList *>(tensor);
    if (tensor_list == nullptr) {
      MS_LOG(ERROR) << "cast tensor_list failed";
      return nullptr;
    }
    auto tensor_info = new int[tensor_list->element_shape().size() + 2];
    tensor_info[0] = tensor_list->tensors_data_type();
    tensor_info[1] = tensor_list->element_shape().size();
    for (size_t i = 0; i < tensor_list->element_shape().size(); ++i) {
      tensor_info[i + 2] = tensor_list->element_shape()[i];
    }
    para_value_lite->set_tensor_addr(tensor_info);
    para_value_lite->set_tensor_size(tensor_list->element_shape().size() + 2);
  }

  new_abstract->set_value(para_value_lite);
  return new_abstract;
}
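
// For TensorList tensors the abstract carries a packed int buffer laid out as
// [element_dtype, ndim, dim_0, ..., dim_{ndim-1}]; set_tensor_size here
// receives the element count of that buffer rather than a byte size.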

@ -121,13 +151,13 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
  }

  if (utils::isa<ValueNodePtr>(cnode->input(i))) {
    MS_LOG(WARNING) << "input is value node";
    MS_LOG(WARNING) << cnode->fullname_with_scope() << "'s input[" << i << "] is value node";
    continue;
  }

  AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Abstract of CNode is nullptr";
    MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr";
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {

@ -194,7 +224,7 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<
  MS_ASSERT(output_tensors != nullptr);
  auto abstract = cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "abstract is nullptr";
    MS_LOG(ERROR) << "node " << cnode->fullname_with_scope() << " abstract is nullptr";
    return RET_ERROR;
  }
  std::vector<TypeId> types;

@ -264,7 +294,62 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
  return RET_OK;
}

int InferShapePass::StrIsContain(const std::vector<std::string> &total, const std::string &aim) {
  for (size_t i = 0; i < total.size(); i++) {
    if (aim.find(total[i]) != std::string::npos) {
      return i;
    }
  }
  return -1;
}
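
// StrIsContain does substring matching, not equality: it returns the index of
// the first pattern in `total` that occurs anywhere inside `aim`, or -1 if
// none matches.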

STATUS InferShapePass::SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
  // hardcoded input parameter names, mirroring the converter's naming scheme
  std::vector<std::string> inputs_names{};
  for (size_t i = 1; i < cnode->inputs().size(); i++) {
    inputs_names.emplace_back("_input_" + std::to_string(i - 1) + "_parameter");
  }
  // copy the cnode inputs' abstracts to the func_graph inputs
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (utils::isa<ParameterPtr>(node)) {
      auto pos = StrIsContain(inputs_names, node->fullname_with_scope());
      if (pos != -1) {
        auto pnode = utils::cast<ParameterPtr>(node);
        auto input_pnode = utils::cast<ParameterPtr>(cnode->input(pos + 1));
        MS_ASSERT(pnode != nullptr);
        pnode->set_abstract(input_pnode->abstract());
      }
    }
  }
  return RET_OK;
}

STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) {
  auto body_partial_cnode = switch_cnode->input(2)->cast<CNodePtr>();
  MS_ASSERT(body_partial_cnode != nullptr);
  auto body_vnode = body_partial_cnode->input(0)->cast<ValueNodePtr>();
  MS_ASSERT(body_vnode != nullptr);
  auto body_fg = GetValueNode<FuncGraphPtr>(body_vnode);
  MS_ASSERT(body_fg != nullptr);
  AbstractBasePtrList abstract_list;
  auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output());
  for (auto &cnode : body_fg_output_cnode->inputs()) {
    if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) {
      continue;
    }
    abstract_list.push_back(cnode->abstract());
  }

  switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  return RET_OK;
}
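
// The Switch node's result shape follows the loop body's outputs, so its
// abstract is assembled as a tuple from the body graph's output operands
// (skipping the leading primitive value node).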

bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
  if (func_graph->has_flag("HasInferShaped")) {
    return true;
  }

  if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
    MS_LOG(INFO) << "The framework type of the model should be tf/tflite.";
    return false;

@ -287,8 +372,14 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
  auto cnode = node->cast<CNodePtr>();
  auto origin_primc = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(cnode->input(0));
  if (origin_primc == nullptr) {
    MS_LOG(ERROR) << "origin_primc is nullptr";
    return false;
    auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
    if (sub_func_graph == nullptr) {
      MS_LOG(ERROR) << "node " << node->fullname_with_scope() << "'s origin_primc is nullptr";
      return false;
    } else {
      MS_LOG(WARNING) << "subgraph infer shape invalid.";
      return RET_INFER_INVALID;
    }
  }
  auto origin_primt = origin_primc->primitiveT();
  if (origin_primt == nullptr) {

@ -296,6 +387,15 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
    return false;
  }
  auto type = GetCNodeType(cnode);

  if (type == schema::PrimitiveType_Switch) {
    int ret = SwitchCNodeInferShape(cnode);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "SwitchCNodeInferShape failed.";
      return false;
    }
  }

  if ((type == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
      (type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) ||

@ -41,6 +41,9 @@ class InferShapePass : public Pass {
  STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
  STATUS SetParameterAbstract(const ParameterPtr &parameter);
  STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode);
  STATUS SwitchCNodeInferShape(const CNodePtr &cnode);
  int StrIsContain(const std::vector<std::string> &total, const std::string &aim);
  STATUS SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph);

 private:
  FmkType fmk_type = lite::converter::FmkType_ONNX;

@ -0,0 +1,181 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/while_pass.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "mindspore/lite/include/errorcode.h"
#include "mindspore/lite/src/ops/primitive_c.h"
#include "tools/anf_importer/import_from_meta_graphT.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/ops/primitive_c.h"
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "src/common/log_adapter.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"

namespace mindspore::opt {

ValueNodePtr WhilePass::GetSwitchAnfPrim() {
  auto switch_primitiveT = new (std::nothrow) schema::PrimitiveT;
  if (switch_primitiveT == nullptr) {
    MS_LOG(ERROR) << "new switch_primitiveT failed";
    return nullptr;
  }
  switch_primitiveT->value.type = schema::PrimitiveType_Switch;
  switch_primitiveT->value.value = new (std::nothrow) schema::SwitchT;
  if (switch_primitiveT->value.value == nullptr) {
    MS_LOG(ERROR) << "new SwitchT failed";
    return nullptr;
  }

  auto partial_prim = std::make_shared<lite::Partial>(switch_primitiveT);
  ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
  return partial_anf_prim;
}
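
// Note: the ValueNode above wraps the Switch primitiveT in the lite::Partial
// PrimitiveC subclass; the op type downstream passes see is presumably still
// taken from primitiveT->value.type, i.e. PrimitiveType_Switch.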

void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode,
                             std::string para_name) {
  for (auto &node : node_list) {
    if (utils::isa<CNodePtr>(node)) {
      auto cnode = utils::cast<CNodePtr>(node);
      for (size_t k = 0; k < cnode->inputs().size(); k++) {
        if (!utils::isa<ParameterPtr>(cnode->input(k))) {
          continue;
        }
        auto para_input = utils::cast<ParameterPtr>(cnode->input(k));
        if (para_input->name() == para_name) {
          cnode->set_input(k, new_input_cnode);
        }
      }
    }
  }
}
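
// Run rewrites every While(cond_fg, body_fg, args...) into
// Switch(Partial(cond_fg, args...), Partial(body_fg, args...)): the body
// graph's outputs are rerouted into a body-to-cond partial so iteration can
// re-test the condition, and every user of the while node is re-pointed at
// the new switch node.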
bool WhilePass::Run(const FuncGraphPtr &graph) {
  auto node_list = TopoSort(graph->get_return());
  static int count = 0;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    if (opt::GetCNodeType(node) != schema::PrimitiveType_While) {
      continue;
    }
    auto while_cnode = node->cast<CNodePtr>();
    MS_ASSERT(while_cnode != nullptr);
    if (while_cnode->inputs().size() < kWhileMinInputSize) {
      MS_LOG(ERROR) << "while input is not right.";
      return false;
    }

    // the order is fixed.
    auto cond_vnode = while_cnode->input(kWhileCondIndex);
    auto body_vnode = while_cnode->input(kWhileBodyIndex);

    auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
    auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);

    if (cond_fg == nullptr || body_fg == nullptr) {
      MS_LOG(ERROR) << "Get value as func_graph failed.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }

    // create cond partial cnode
    std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode};

    // create body partial cnode
    std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode};

    // add the while op inputs to both the cond and the body partial
    cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
                                  while_cnode->inputs().end());
    body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
                                  while_cnode->inputs().end());

    static int idx = 0;
    auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs);
    cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx));
    cond_partial_node->set_abstract(cond_fg->output()->abstract());

    auto body_partial_node = graph->NewCNode(body_partial_op_inputs);
    body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx));
    idx++;

    // concat body_fg output to cond_fg input
    auto body_output = body_fg->output();
    auto body_output_cnode = utils::cast<CNodePtr>(body_output);
    auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(body_output_cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(ERROR) << "Get PrimitiveC of node:" << body_output_cnode->fullname_with_scope() << " failed.";
      return false;
    }

    // collect the body outputs that feed cond
    std::vector<AnfNodePtr> body_to_cond_inputs{cond_vnode};
    if ((schema::PrimitiveType)(prim->Type()) == schema::PrimitiveType_MakeTuple) {
      for (size_t i = 1; i < body_output_cnode->inputs().size(); ++i) {
        body_to_cond_inputs.emplace_back(body_output_cnode->input(i));
      }
    } else {
      body_to_cond_inputs.emplace_back(body_output_cnode->input(1));
    }

    // replace the body graph's output with the body-to-cond node
    auto body_to_cond_cnode = body_fg->NewCNode(body_to_cond_inputs);
    body_to_cond_cnode->set_fullname_with_scope("Partial-while-body-to-cond");
    auto body_fg_manager = body_fg->manager();
    body_fg_manager->Replace(body_fg->output(), body_to_cond_cnode);
    body_fg->set_output(body_to_cond_cnode);
    body_partial_node->set_abstract(cond_fg->output()->abstract());

    // create switch cnode
    ValueNodePtr switch_anf_primitive = GetSwitchAnfPrim();
    if (switch_anf_primitive == nullptr) {
      MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
      return false;
    }

    // insert switch node
    std::vector<AnfNodePtr> switch_op_inputs = {switch_anf_primitive, cond_partial_node, body_partial_node};
    auto switch_cnode = graph->NewCNode(switch_op_inputs);
    switch_cnode->set_fullname_with_scope("Switch-" + std::to_string(count++));

    AbstractBasePtrList abstract_list;
    auto body_fg_output_cnode = utils::cast<CNodePtr>(body_fg->output());
    for (auto &cnode : body_fg_output_cnode->inputs()) {
      if (!utils::isa<CNodePtr>(cnode) && !utils::isa<ParameterPtr>(cnode)) {
        continue;
      }
      abstract_list.push_back(cnode->abstract());
    }

    switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));

    // redirect all users of the while node to the new switch node
    auto manager = graph->manager();
    auto node_users = manager->node_users()[while_cnode];
    for (auto &node_user : node_users) {
      manager->SetEdge(node_user.first, node_user.second, switch_cnode);
    }
  }

  return true;
}
}  // namespace mindspore::opt

@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "src/param_value_lite.h"

using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class WhilePass : public Pass {
 public:
  WhilePass() : Pass("while_pass") {}
  ~WhilePass() override = default;
  bool Run(const FuncGraphPtr &graph) override;

 private:
  void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name);
  ValueNodePtr GetSwitchAnfPrim();

  const size_t kWhileMinInputSize = 3;
  const size_t kWhileCondIndex = 1;
  const size_t kWhileBodyIndex = 2;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WHILE_PASS_H_