forked from mindspore-Ecosystem/mindspore
!2153 Add transdata for output node in pynative hook mode
Merge pull request !2153 from JoyLvliang/pynative-insert-transdata-for-hook-mode
This commit is contained in:
commit
8903b50042
|
@ -51,7 +51,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|||
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode) {
|
||||
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
|
||||
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -74,6 +74,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
precompile_only_ = false;
|
||||
auto_mixed_precision_flag_ = false;
|
||||
enable_pynative_infer_ = false;
|
||||
enable_pynative_hook_ = false;
|
||||
enable_dynamic_mem_pool_ = true;
|
||||
graph_memory_max_size_ = "0";
|
||||
variable_memory_max_size_ = "0";
|
||||
|
|
|
@ -64,6 +64,9 @@ class MsContext {
|
|||
bool enable_pynative_infer() const { return enable_pynative_infer_; }
|
||||
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
|
||||
|
||||
bool enable_pynative_hook() const { return enable_pynative_hook_; }
|
||||
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }
|
||||
|
||||
bool enable_task_sink() const { return enable_task_sink_; }
|
||||
|
||||
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
|
||||
|
@ -161,6 +164,7 @@ class MsContext {
|
|||
uint32_t device_id_;
|
||||
int execution_mode_;
|
||||
bool enable_pynative_infer_;
|
||||
bool enable_pynative_hook_;
|
||||
bool save_graphs_flag_;
|
||||
std::string save_graphs_path_;
|
||||
uint32_t tsd_ref_;
|
||||
|
|
|
@ -277,6 +277,11 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|||
for (auto &prim : cut_list_) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == node_prim->name()) {
|
||||
if (prim->name() == prim::kPrimBpropCut->name()) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_enable_pynative_hook(true);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue