!29326 fix DynamicStitch FlattenGrad VitualXX Fill primitive tuple input incorrect elimination.

Merge pull request !29326 from huanghui/tuple-input
This commit is contained in:
i-robot 2022-01-20 11:17:39 +00:00 committed by Gitee
commit 3070cfcb5d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 15 additions and 0 deletions

View File

@ -90,6 +90,16 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
prim::kPrimSparseTensorDenseMatmul->name(),
prim::kPrimBroadcast->name(),
prim::kPrimEinsumGrad->name(),
prim::kPrimFlattenGrad->name(),
prim::kPrimMirror->name(),
prim::kPrimMirrorMiniStep->name(),
prim::kPrimMiniStepAllGather->name(),
prim::kPrimMicroStepAllGather->name(),
prim::kPrimVirtualDiv->name(),
prim::kPrimVirtualAdd->name(),
prim::kPrimVirtualDataset->name(),
prim::kPrimVirtualOutput->name(),
prim::kPrimFill->name(),
"InvertPermutation",
"Meshgrid",
"TransShape",

View File

@ -1369,6 +1369,8 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
min_shape[0] = 1;
max_shape[0] = indices_total_size * EXPAND_MAX;
}
SetSequenceElementsUseFlags(input_tuple, true);
SetSequenceElementsUseFlags(input_tuple_1, true);
return std::make_shared<AbstractTensor>(infer_type,
std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
}
@ -1379,6 +1381,9 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi
constexpr auto kTensorCopySlicesInputNum = 5;
CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
for (size_t i = 2; i < args_spec_list.size(); ++i) {
SetSequenceElementsUseFlags(args_spec_list[i], true);
}
return std::make_shared<AbstractTensor>(input->element(), input->shape());
}
} // namespace abstract