forked from mindspore-Ecosystem/mindspore
!2153 Add transdata for output node in pynative hook mode
Merge pull request !2153 from JoyLvliang/pynative-insert-transdata-for-hook-mode
This commit is contained in:
commit
8903b50042
|
@ -51,7 +51,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
||||||
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
|
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
if (ms_context->execution_mode() == kPynativeMode) {
|
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
|
||||||
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
||||||
return new_node;
|
return new_node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||||
precompile_only_ = false;
|
precompile_only_ = false;
|
||||||
auto_mixed_precision_flag_ = false;
|
auto_mixed_precision_flag_ = false;
|
||||||
enable_pynative_infer_ = false;
|
enable_pynative_infer_ = false;
|
||||||
|
enable_pynative_hook_ = false;
|
||||||
enable_dynamic_mem_pool_ = true;
|
enable_dynamic_mem_pool_ = true;
|
||||||
graph_memory_max_size_ = "0";
|
graph_memory_max_size_ = "0";
|
||||||
variable_memory_max_size_ = "0";
|
variable_memory_max_size_ = "0";
|
||||||
|
|
|
@ -64,6 +64,9 @@ class MsContext {
|
||||||
bool enable_pynative_infer() const { return enable_pynative_infer_; }
|
bool enable_pynative_infer() const { return enable_pynative_infer_; }
|
||||||
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
|
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
|
||||||
|
|
||||||
|
bool enable_pynative_hook() const { return enable_pynative_hook_; }
|
||||||
|
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }
|
||||||
|
|
||||||
bool enable_task_sink() const { return enable_task_sink_; }
|
bool enable_task_sink() const { return enable_task_sink_; }
|
||||||
|
|
||||||
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
|
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
|
||||||
|
@ -161,6 +164,7 @@ class MsContext {
|
||||||
uint32_t device_id_;
|
uint32_t device_id_;
|
||||||
int execution_mode_;
|
int execution_mode_;
|
||||||
bool enable_pynative_infer_;
|
bool enable_pynative_infer_;
|
||||||
|
bool enable_pynative_hook_;
|
||||||
bool save_graphs_flag_;
|
bool save_graphs_flag_;
|
||||||
std::string save_graphs_path_;
|
std::string save_graphs_path_;
|
||||||
uint32_t tsd_ref_;
|
uint32_t tsd_ref_;
|
||||||
|
|
|
@ -277,6 +277,11 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
||||||
for (auto &prim : cut_list_) {
|
for (auto &prim : cut_list_) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
if (prim->name() == node_prim->name()) {
|
if (prim->name() == node_prim->name()) {
|
||||||
|
if (prim->name() == prim::kPrimBpropCut->name()) {
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
ms_context->set_enable_pynative_hook(true);
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue