!2153 Add transdata for output node in pynative hook mode

Merge pull request !2153 from JoyLvliang/pynative-insert-transdata-for-hook-mode
This commit is contained in:
mindspore-ci-bot 2020-06-16 20:20:13 +08:00 committed by Gitee
commit 8903b50042
4 changed files with 11 additions and 1 deletions

View File

@ -51,7 +51,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) { if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node; return new_node;
} }

View File

@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
precompile_only_ = false; precompile_only_ = false;
auto_mixed_precision_flag_ = false; auto_mixed_precision_flag_ = false;
enable_pynative_infer_ = false; enable_pynative_infer_ = false;
enable_pynative_hook_ = false;
enable_dynamic_mem_pool_ = true; enable_dynamic_mem_pool_ = true;
graph_memory_max_size_ = "0"; graph_memory_max_size_ = "0";
variable_memory_max_size_ = "0"; variable_memory_max_size_ = "0";

View File

@ -64,6 +64,9 @@ class MsContext {
bool enable_pynative_infer() const { return enable_pynative_infer_; } bool enable_pynative_infer() const { return enable_pynative_infer_; }
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
bool enable_pynative_hook() const { return enable_pynative_hook_; }
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }
bool enable_task_sink() const { return enable_task_sink_; } bool enable_task_sink() const { return enable_task_sink_; }
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
@ -161,6 +164,7 @@ class MsContext {
uint32_t device_id_; uint32_t device_id_;
int execution_mode_; int execution_mode_;
bool enable_pynative_infer_; bool enable_pynative_infer_;
bool enable_pynative_hook_;
bool save_graphs_flag_; bool save_graphs_flag_;
std::string save_graphs_path_; std::string save_graphs_path_;
uint32_t tsd_ref_; uint32_t tsd_ref_;

View File

@ -277,6 +277,11 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
for (auto &prim : cut_list_) { for (auto &prim : cut_list_) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == node_prim->name()) { if (prim->name() == node_prim->name()) {
if (prim->name() == prim::kPrimBpropCut->name()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_enable_pynative_hook(true);
}
return true; return true;
} }
} }