forked from mindspore-Ecosystem/mindspore
!11159 fix inceptionv3 kernel build error in pynative
From: @chujinjin Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
eb51ce93bd
|
@ -860,6 +860,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!";
|
||||
// malloc mem
|
||||
RunOpRemoveNopNode(graph);
|
||||
RunOpMemoryAlloc(*input_tensors, graph.get());
|
||||
// Build dynamic kernel
|
||||
if (op_run_info->is_dynamic_shape) {
|
||||
|
@ -1018,6 +1019,7 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
|
||||
void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
RunOpHideNopNode(kernel_graph);
|
||||
// Insert ClearZero op
|
||||
// prepare for next step from json get atomic info
|
||||
BuildKernel(kernel_graph);
|
||||
|
|
|
@ -418,6 +418,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
|
|||
SelectKernel(kernel_graph);
|
||||
RunOpHardwareOptimize(kernel_graph);
|
||||
StartKernelRT();
|
||||
RunOpHideNopNode(kernel_graph);
|
||||
BuildKernel(kernel_graph);
|
||||
run_op_graphs_[graph_info] = kernel_graph;
|
||||
}
|
||||
|
@ -432,6 +433,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
|||
// run op
|
||||
auto kernel_graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
RunOpRemoveNopNode(kernel_graph);
|
||||
RunOpAllocateMemory(*input_tensors, kernel_graph.get());
|
||||
// Execute the computation
|
||||
LoadInputData(kernel_graph, *input_tensors);
|
||||
|
|
|
@ -1174,7 +1174,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
auto &node = item.second.first;
|
||||
auto &output_index = item.second.second;
|
||||
DeviceAddressPtr address = nullptr;
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
|
||||
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
||||
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
|
||||
} else {
|
||||
address = AnfAlgo::GetMutableOutputAddr(node, output_index);
|
||||
|
@ -1761,6 +1762,22 @@ bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_
|
|||
return false;
|
||||
}
|
||||
|
||||
void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
||||
opt::RemoveNopNode(kernel_graph.get());
|
||||
}
|
||||
}
|
||||
|
||||
void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
||||
opt::HideNopNode(kernel_graph.get());
|
||||
}
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
||||
if (!ps::Util::IsRoleOfWorker()) {
|
||||
|
|
|
@ -194,6 +194,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
|
||||
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
|
||||
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
|
||||
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
|
||||
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
|
||||
void GetBatchElements(const AnfNodePtr &kernel_node) const;
|
||||
|
|
Loading…
Reference in New Issue