diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 13ab323ac91..0b6c5f7354f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -58,17 +58,24 @@ mindspore::HashSet prims_to_skip_undetermined_infer{ // We consider all tuple/list arguments are used by now. // Should check 'tuple argument index' and 'element use index' later. mindspore::HashSet prims_use_sequence_elements{prim::kPrimStack->name(), - prim::kPrimBroadcast->name(), prim::kPrimConcat->name(), prim::kPrimTupleToArray->name(), prim::kPrimPack->name(), prim::kPrimSlice->name(), prim::kPrimStridedSlice->name(), prim::kPrimScatterNd->name(), + prim::kPrimReshape->name(), + prim::kPrimTile->name(), + prim::kPrimConv3DBackpropFilter->name(), + prim::kPrimCentralization->name(), + prim::kPrimMerge->name(), + prim::kPrimCustom->name(), + prim::kPrimAssert->name(), "InvertPermutation", "Meshgrid", "TransShape", - "ParallelConcat"}; + "ParallelConcat", + "CudnnGRU"}; EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &out_conf) { diff --git a/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc b/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc index d5e5383aaa7..17d9a773225 100644 --- a/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc +++ b/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc @@ -122,7 +122,14 @@ AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEngine const int64_t input_num = 2; (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)), kEqual, input_num, prim_name); - return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); + auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor, DynamicResizeNearestNeighborInfer, nullptr, true); diff --git a/mindspore/core/ops/grad/avg_pool_3d_grad.cc b/mindspore/core/ops/grad/avg_pool_3d_grad.cc index 582c0c3a33b..c4454156f16 100644 --- a/mindspore/core/ops/grad/avg_pool_3d_grad.cc +++ b/mindspore/core/ops/grad/avg_pool_3d_grad.cc @@ -62,8 +62,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) { - return std::make_shared(InferType(primitive, input_args), - InferShape(primitive, input_args)->shape()); + auto res = std::make_shared(InferType(primitive, input_args), + InferShape(primitive, input_args)->shape()); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3DGrad, prim::kPrimAvgPool3DGrad, AvgPool3DGradInfer, nullptr, true); diff --git a/mindspore/core/ops/grad/conv2d_backprop_filter.cc b/mindspore/core/ops/grad/conv2d_backprop_filter.cc index 79836e22a7d..548381a2680 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_filter.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_filter.cc @@ -239,8 +239,15 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - return std::make_shared(Conv2DBackpropFilterInferType(primitive, input_args), - Conv2DBackpropFilterInferShape(primitive, input_args)); + auto res = std::make_shared(Conv2DBackpropFilterInferType(primitive, input_args), + Conv2DBackpropFilterInferShape(primitive, input_args)); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr, true); diff --git a/mindspore/core/ops/grad/dropout_grad.cc b/mindspore/core/ops/grad/dropout_grad.cc index e68e24eb530..b5432008564 100644 --- a/mindspore/core/ops/grad/dropout_grad.cc +++ b/mindspore/core/ops/grad/dropout_grad.cc @@ -53,7 +53,14 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim const std::set valid_types = {kFloat16, kFloat32}; auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name); auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index); - return abstract::MakeAbstract(shape, out_type); + auto res = abstract::MakeAbstract(shape, out_type); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/grad/strided_slice_grad.cc b/mindspore/core/ops/grad/strided_slice_grad.cc index 2cb671e24fc..80208ffbe19 100644 --- a/mindspore/core/ops/grad/strided_slice_grad.cc +++ b/mindspore/core/ops/grad/strided_slice_grad.cc @@ -120,8 +120,15 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const MS_EXCEPTION_IF_NULL(primitive); const int64_t input_num = 5; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); - return abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args), - StridedSliceGradInferType(primitive, input_args)); + auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args), + StridedSliceGradInferType(primitive, input_args)); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } void StridedSliceGrad::set_begin_mask(int64_t begin_mask) { diff --git a/mindspore/core/ops/neighborexchange.cc b/mindspore/core/ops/neighborexchange.cc index edae7a7c9bb..20b4ef2dc46 100644 --- a/mindspore/core/ops/neighborexchange.cc +++ b/mindspore/core/ops/neighborexchange.cc @@ -185,7 +185,13 @@ AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const Check(primitive, input_args); auto type = InferType(primitive); auto shape = InferShape(primitive); - return abstract::MakeAbstract(shape, type); + auto res = abstract::MakeAbstract(shape, type); + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/ones.cc b/mindspore/core/ops/ones.cc index 07a3799f41f..2141a0fda1d 100644 --- a/mindspore/core/ops/ones.cc +++ b/mindspore/core/ops/ones.cc @@ -53,8 +53,14 @@ AbstractBasePtr OnesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt const std::string op_name = primitive->name(); const int64_t input_num = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name); - - return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args)); + auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args)); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector &input_args) { diff --git a/mindspore/core/ops/transpose.cc b/mindspore/core/ops/transpose.cc index b7ecd487e05..64d185a7b00 100644 --- a/mindspore/core/ops/transpose.cc +++ b/mindspore/core/ops/transpose.cc @@ -97,7 +97,14 @@ AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const Primit primitive->name()); auto type = InferType(primitive, input_args); auto shape = InferShape(primitive, input_args); - return abstract::MakeAbstract(shape, type); + auto res = abstract::MakeAbstract(shape, type); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc index c29f86e6c96..fc4cb77580f 100644 --- a/mindspore/core/ops/zeros.cc +++ b/mindspore/core/ops/zeros.cc @@ -52,7 +52,14 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args)); + auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args)); + // Set all used flags of tuple as true. + for (size_t i = 0; i < input_args.size(); i++) { + if (input_args[i] != nullptr) { + SetSequenceElementsUseFlags(input_args[i], true); + } + } + return res; } ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector &input_args) {