From: @xu_anyue
Reviewed-by: @zh_qh,@zhanghaibo5
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-02-09 17:45:40 +08:00 committed by Gitee
commit eba691ba25
2 changed files with 42 additions and 1 deletions

View File

@ -16,6 +16,7 @@
#include "tools/anf_exporter/anf_exporter.h"
#include <list>
#include <memory>
#include <string>
#include <utility>
@ -32,6 +33,45 @@
#include "tools/common/graph_util.h"
namespace mindspore::lite {
namespace {
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
std::vector<AnfNodePtr> vecs;
if (node == nullptr) {
return vecs;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
// Check if free variables used.
for (const auto &input : inputs) {
auto input_fg = GetValueNode<FuncGraphPtr>(input);
if (input_fg) {
for (auto &fv : input_fg->free_variables_nodes()) {
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
vecs.push_back(fv);
}
}
}
}
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
}
return vecs;
};
std::list<CNodePtr> cnodes;
auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
for (const auto &node : nodes) {
auto cnode = dyn_cast<CNode>(node);
if (cnode) {
cnodes.push_back(cnode);
}
}
return cnodes;
}
} // namespace
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
bool has_make_tuple = false;
std::vector<AnfNodePtr> inputs;
@ -220,7 +260,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive,
const std::unique_ptr<schema::SubGraphT> &sub_graphT) {
int ret = RET_OK;
auto cnodes = func_graph->GetOrderedCnodes();
auto cnodes = GetOrderedCNodes(func_graph);
for (const auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {

View File

@ -178,6 +178,7 @@ int RunConverter(int argc, const char **argv) {
delete fb_graph;
MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status;
std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
return status;
}
} // namespace lite