!11159 fix inceptionv3 kernel build error in pynative

From: @chujinjin
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-01-11 21:11:41 +08:00 committed by Gitee
commit eb51ce93bd
4 changed files with 24 additions and 1 deletions

View File

@ -860,6 +860,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!";
// malloc mem
RunOpRemoveNopNode(graph);
RunOpMemoryAlloc(*input_tensors, graph.get());
// Build dynamic kernel
if (op_run_info->is_dynamic_shape) {
@ -1018,6 +1019,7 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
RunOpHideNopNode(kernel_graph);
// Insert CLearZero op
// prepare for next step from json get atomic info
BuildKernel(kernel_graph);

View File

@ -418,6 +418,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
SelectKernel(kernel_graph);
RunOpHardwareOptimize(kernel_graph);
StartKernelRT();
RunOpHideNopNode(kernel_graph);
BuildKernel(kernel_graph);
run_op_graphs_[graph_info] = kernel_graph;
}
@ -432,6 +433,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
// run op
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpRemoveNopNode(kernel_graph);
RunOpAllocateMemory(*input_tensors, kernel_graph.get());
// Execute the computation
LoadInputData(kernel_graph, *input_tensors);

View File

@ -1174,7 +1174,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
auto &node = item.second.first;
auto &output_index = item.second.second;
DeviceAddressPtr address = nullptr;
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
} else {
address = AnfAlgo::GetMutableOutputAddr(node, output_index);
@ -1761,6 +1762,22 @@ bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_
return false;
}
void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
opt::RemoveNopNode(kernel_graph.get());
}
}
void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
opt::HideNopNode(kernel_graph.get());
}
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::Util::IsRoleOfWorker()) {

View File

@ -194,6 +194,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
void GetBatchElements(const AnfNodePtr &kernel_node) const;