fix dynamic parameters

This commit is contained in:
jjfeing 2021-06-24 19:09:55 +08:00
parent 5ee636c519
commit 851f4b46dd
2 changed files with 5 additions and 6 deletions

View File

@ -222,6 +222,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const bool need_padding, const std::string &op_name, const std::vector<int64_t> &perm) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(kernel_select);
CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
MS_EXCEPTION_IF_NULL(trans_node);
if (need_padding) {
@ -243,12 +244,10 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
if (op_name == prim::kPrimTranspose->name()) {
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node);
}
MS_EXCEPTION_IF_NULL(kernel_select);
kernel_select->SelectKernel(trans_node);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node);
MS_EXCEPTION_IF_NULL(trans_node);
trans_node->set_scope(input->scope());
kernel_select->SelectKernel(trans_node);
return trans_node;
}

View File

@ -194,6 +194,9 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) {
if (parameters_indexs.empty()) {
return false;
}
if (AnfAlgo::IsDynamicShape(kernel_node)) {
parameters_indexs.pop_back();
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
@ -201,9 +204,6 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) {
size_t total_num = input_num + workspace_num + output_num;
size_t pad_index = param_num;
if (AnfAlgo::IsDynamicShape(kernel_node)) {
parameters_indexs.pop_back();
}
for (; pad_index < total_num; ++pad_index) {
parameters_indexs.emplace_back(0);
}