!65651 Concat 动态shape问题修复

Merge pull request !65651 from zhengyafei/r2.3
This commit is contained in:
i-robot 2024-02-29 03:26:18 +00:00 committed by Gitee
commit 8275373712
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 11 additions and 4 deletions

View File

@ -27,3 +27,5 @@ deeplabv3_vocaug_bs_1_ascend -1
resnet50_imagenet_bs_1_Ascend 3
resnet50_imagenet_bs_1_GPU 3
unet3d_luna16_bs_1_ascend
fasterrcnn_coco2017_bs_1_ascend -1
lstm_aclimdb_bs_64_Ascend -1

View File

@ -76,7 +76,10 @@ static STATUS AdapteNodeWithDynamicInput(const CNodePtr &cnode) {
cnode->set_inputs(new_inputs);
// add kAttrDynInputSizes for multi-input operator.
int64_t input_num = tuple_node->size() - 1;
prim->AddAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{input_num, -1}));
auto dst_prim = prim->Clone();
dst_prim->AddAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{input_num, -1}));
ValueNodePtr value_node = cnode->input(0)->cast<ValueNodePtr>();
value_node->set_value(dst_prim);
return lite::RET_OK;
}

View File

@ -46,7 +46,7 @@ bool ArgsToAttrPass::Run(const FuncGraphPtr &func_graph) {
if (prim == nullptr) {
continue;
}
auto dst_prim = prim->Clone();
auto node_inputs = cnode->inputs();
std::vector<AnfNodePtr> new_node_inputs;
@ -85,10 +85,12 @@ bool ArgsToAttrPass::Run(const FuncGraphPtr &func_graph) {
}
auto arg_value_node = arg_input_node->cast<ValueNodePtr>();
auto arg_value = arg_value_node->value();
prim->AddAttr(arg.arg_name_, arg_value);
dst_prim->AddAttr(arg.arg_name_, arg_value);
ValueNodePtr value_node = cnode->input(0)->cast<ValueNodePtr>();
value_node->set_value(dst_prim);
}
auto new_node = func_graph->NewCNode(prim, new_node_inputs);
auto new_node = func_graph->NewCNode(dst_prim, new_node_inputs);
new_node->set_abstract(node->abstract());
new_node->set_fullname_with_scope(node->fullname_with_scope());