!12310 test
From: @xu_anyue Reviewed-by: @zh_qh,@zhanghaibo5 Signed-off-by: @zh_qh
This commit is contained in:
commit
eba691ba25
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -32,6 +33,45 @@
|
|||
#include "tools/common/graph_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
||||
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
||||
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
||||
std::vector<AnfNodePtr> vecs;
|
||||
if (node == nullptr) {
|
||||
return vecs;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
// Check if free variables used.
|
||||
for (const auto &input : inputs) {
|
||||
auto input_fg = GetValueNode<FuncGraphPtr>(input);
|
||||
if (input_fg) {
|
||||
for (auto &fv : input_fg->free_variables_nodes()) {
|
||||
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
|
||||
vecs.push_back(fv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
|
||||
}
|
||||
return vecs;
|
||||
};
|
||||
|
||||
std::list<CNodePtr> cnodes;
|
||||
auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
|
||||
for (const auto &node : nodes) {
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode) {
|
||||
cnodes.push_back(cnode);
|
||||
}
|
||||
}
|
||||
return cnodes;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
|
||||
bool has_make_tuple = false;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
|
@ -220,7 +260,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|||
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive,
|
||||
const std::unique_ptr<schema::SubGraphT> &sub_graphT) {
|
||||
int ret = RET_OK;
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
auto cnodes = GetOrderedCNodes(func_graph);
|
||||
for (const auto &cnode : cnodes) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
|
|
|
@ -178,6 +178,7 @@ int RunConverter(int argc, const char **argv) {
|
|||
delete fb_graph;
|
||||
MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status;
|
||||
std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
|
||||
|
||||
return status;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
Loading…
Reference in New Issue