forked from mindspore-Ecosystem/mindspore
fix mix target entry
This commit is contained in:
parent
3654f0f9ac
commit
ae3db6d4de
|
@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa
|
|||
|
||||
bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
|
||||
|
||||
static bool IsCtrlSink(const FuncGraphPtr &graph) {
|
||||
static bool IsCtrlSink() {
|
||||
auto ms_ctx = MsContext::GetInstance();
|
||||
if (ms_ctx->execution_mode() != kGraphMode) {
|
||||
return false;
|
||||
|
@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) {
|
||||
if (!ms_ctx->is_multi_graph_sink()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -310,7 +309,21 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
}
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
|
||||
if (IsCtrlSink(func_graph)) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (CompileGraphs::ContainMixedTarget(func_graph)) {
|
||||
bc_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_loop_sink_flag(false);
|
||||
} else if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
std::string device_target = context_ptr->device_target();
|
||||
if (device_target == kAscendDevice) {
|
||||
bc_ptr->set_is_multi_graph_sink(true);
|
||||
context_ptr->set_is_multi_graph_sink(true);
|
||||
}
|
||||
}
|
||||
|
||||
if (IsCtrlSink()) {
|
||||
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
|
||||
return true;
|
||||
}
|
||||
|
@ -318,19 +331,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
if (bc_ptr->name() == kMsConvert) {
|
||||
cut_list = compile::GetMsNonlinearOps();
|
||||
}
|
||||
|
||||
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (CompileGraphs::ContainMixedTarget(func_graph)) {
|
||||
bc_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_loop_sink_flag(false);
|
||||
} else if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
std::string device_target = context_ptr->device_target();
|
||||
if (device_target == kAscendDevice) {
|
||||
bc_ptr->set_is_multi_graph_sink(true);
|
||||
}
|
||||
}
|
||||
res->results()[kOutput] = compile->CompileAndLink(func_graph);
|
||||
return true;
|
||||
}
|
||||
|
@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) {
|
|||
MS_LOG(EXCEPTION) << "Execute args error";
|
||||
}
|
||||
|
||||
if (IsCtrlSink(nullptr)) {
|
||||
if (IsCtrlSink()) {
|
||||
if (!res->results()[kOutput].is<GraphId>()) {
|
||||
MS_LOG(EXCEPTION) << "Execute args error";
|
||||
}
|
||||
|
||||
auto graph_id = res->results()[kOutput].cast<GraphId>();
|
||||
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
|
||||
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
|
||||
|
|
Loading…
Reference in New Issue