From 851f4b46dd7553a56fa0ec8a81d00ad4f56a1ea1 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Thu, 24 Jun 2021 19:09:55 +0800 Subject: [PATCH] fix dynamic parameters --- mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc | 5 ++--- .../ccsrc/runtime/device/ascend/kernel_build_ascend.cc | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index c52dad48952..e39c50745b1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -222,6 +222,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const bool need_padding, const std::string &op_name, const std::vector &perm) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(kernel_select); CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared(op_name)), input}); MS_EXCEPTION_IF_NULL(trans_node); if (need_padding) { @@ -243,12 +244,10 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, if (op_name == prim::kPrimTranspose->name()) { AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node); } - MS_EXCEPTION_IF_NULL(kernel_select); - kernel_select->SelectKernel(trans_node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue>({}), trans_node); - MS_EXCEPTION_IF_NULL(trans_node); trans_node->set_scope(input->scope()); + kernel_select->SelectKernel(trans_node); return trans_node; } diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc index 179f203757c..8be3d5b4749 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -194,6 +194,9 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) { if (parameters_indexs.empty()) { return false; } + if (AnfAlgo::IsDynamicShape(kernel_node)) { + parameters_indexs.pop_back(); + } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); @@ -201,9 +204,6 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) { size_t total_num = input_num + workspace_num + output_num; size_t pad_index = param_num; - if (AnfAlgo::IsDynamicShape(kernel_node)) { - parameters_indexs.pop_back(); - } for (; pad_index < total_num; ++pad_index) { parameters_indexs.emplace_back(0); }