forked from mindspore-Ecosystem/mindspore
!29326 fix DynamicStitch FlattenGrad VitualXX Fill primitive tuple input incorrect elimination.
Merge pull request !29326 from huanghui/tuple-input
This commit is contained in:
commit
3070cfcb5d
|
@ -90,6 +90,16 @@ mindspore::HashSet<std::string> prims_use_sequence_elements{
|
|||
prim::kPrimSparseTensorDenseMatmul->name(),
|
||||
prim::kPrimBroadcast->name(),
|
||||
prim::kPrimEinsumGrad->name(),
|
||||
prim::kPrimFlattenGrad->name(),
|
||||
prim::kPrimMirror->name(),
|
||||
prim::kPrimMirrorMiniStep->name(),
|
||||
prim::kPrimMiniStepAllGather->name(),
|
||||
prim::kPrimMicroStepAllGather->name(),
|
||||
prim::kPrimVirtualDiv->name(),
|
||||
prim::kPrimVirtualAdd->name(),
|
||||
prim::kPrimVirtualDataset->name(),
|
||||
prim::kPrimVirtualOutput->name(),
|
||||
prim::kPrimFill->name(),
|
||||
"InvertPermutation",
|
||||
"Meshgrid",
|
||||
"TransShape",
|
||||
|
|
|
@ -1369,6 +1369,8 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
|
|||
min_shape[0] = 1;
|
||||
max_shape[0] = indices_total_size * EXPAND_MAX;
|
||||
}
|
||||
SetSequenceElementsUseFlags(input_tuple, true);
|
||||
SetSequenceElementsUseFlags(input_tuple_1, true);
|
||||
return std::make_shared<AbstractTensor>(infer_type,
|
||||
std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
|
||||
}
|
||||
|
@ -1379,6 +1381,9 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi
|
|||
constexpr auto kTensorCopySlicesInputNum = 5;
|
||||
CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum);
|
||||
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
for (size_t i = 2; i < args_spec_list.size(); ++i) {
|
||||
SetSequenceElementsUseFlags(args_spec_list[i], true);
|
||||
}
|
||||
return std::make_shared<AbstractTensor>(input->element(), input->shape());
|
||||
}
|
||||
} // namespace abstract
|
||||
|
|
Loading…
Reference in New Issue