From 0deb008237caa909bbad549a2d401e45b36a444d Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Thu, 20 Jan 2022 12:35:22 +0800 Subject: [PATCH] fixed 4e7c012 from https://gitee.com/zh_qh/mindspore/pulls/28963 Handle the primitives who transparent pass or partial use tuple/list. --- .../operator/ops_front_infer_function.cc | 10 -- .../pipeline/jit/static_analysis/prim.cc | 151 +++++++++--------- mindspore/core/abstract/infer_functions.h | 1 - mindspore/core/abstract/prim_arrays.cc | 9 -- mindspore/core/abstract/prim_others.cc | 12 -- mindspore/core/abstract/prim_statement.cc | 3 - mindspore/core/abstract/prim_structures.cc | 3 - mindspore/core/ops/accumulate_n_v2.cc | 1 - mindspore/core/ops/addn.cc | 5 +- mindspore/core/ops/ctc_loss_v2.cc | 1 - mindspore/core/ops/ctc_loss_v2_grad.cc | 1 - .../ops/dynamic_resize_nearest_neighbor.cc | 3 - mindspore/core/ops/einsum.cc | 2 - mindspore/core/ops/grad/avg_pool_3d_grad.cc | 3 - .../core/ops/grad/conv2d_backprop_filter.cc | 2 - mindspore/core/ops/grad/dropout_grad.cc | 3 - mindspore/core/ops/grad/strided_slice_grad.cc | 9 -- mindspore/core/ops/lp_norm.cc | 1 - mindspore/core/ops/neighborexchange.cc | 8 +- mindspore/core/ops/ones.cc | 6 +- mindspore/core/ops/transpose.cc | 9 +- mindspore/core/ops/zeros.cc | 6 +- 22 files changed, 81 insertions(+), 168 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index ab09f392b95..f664d8ec7e5 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -51,9 +51,6 @@ AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const Abst CheckArgsSize(op_name, args_spec_list, 2); auto input_x = CheckArg(op_name, args_spec_list, 0); auto input_y = CheckArg(op_name, args_spec_list, 1); - SetSequenceElementsUseFlags(input_x, true); - SetSequenceElementsUseFlags(input_y, true); - ValuePtr x_value = input_x->BuildValue(); ValuePtr y_value = input_y->BuildValue(); return std::make_shared(*x_value == *y_value); @@ -298,8 +295,6 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const return std::make_shared(elem_list); } - SetSequenceElementsUseFlags(arg_x, true); - SetSequenceElementsUseFlags(arg_y, true); return BroadcastGradientArgsDiff(x_shape, y_shape); } @@ -354,7 +349,6 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi MS_EXCEPTION_IF_NULL(result2); MS_EXCEPTION_IF_NULL(result1->abstract()); MS_EXCEPTION_IF_NULL(result2->abstract()); - SetSequenceElementsUseFlags(lst, true); return result1->abstract()->Join(result2->abstract()); } @@ -370,7 +364,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv AbstractBasePtrList elem_list; (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list), [](const AbstractBasePtr &elem) { return elem->Clone(); }); - SetSequenceElementsUseFlags(input, true); return std::make_shared(elem_list); } @@ -412,7 +405,6 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP } auto axis_value_ptr = axis_value->cast(); MS_EXCEPTION_IF_NULL(axis_value_ptr); - SetSequenceElementsUseFlags(shape_x, true); return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive); } @@ -477,8 +469,6 @@ AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr auto result_v = MakeValue(result); values.push_back(std::make_shared(result_v, result_v->type())); } - SetSequenceElementsUseFlags(shape_x, true); - SetSequenceElementsUseFlags(div_shp, true); return std::make_shared(values); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 01d66244f0e..76831582b2f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -52,70 +52,20 @@ namespace abstract { using mindspore::parse::PyObjectWrapper; mindspore::HashSet prims_to_skip_undetermined_infer{ - "MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"}; + prim::kPrimMakeTuple->name(), prim::kPrimMakeList->name(), prim::kPrimSwitch->name(), + prim::kPrimEnvironSet->name(), prim::kPrimEnvironGet->name(), prim::kPrimLoad->name(), + prim::kPrimUpdateState->name()}; -// The Python primitives who use tuple/list elements. -// We consider all tuple/list arguments are used by now. -// Should check 'tuple argument index' and 'element use index' later. -mindspore::HashSet prims_use_sequence_elements{ - prim::kPrimStack->name(), - prim::kPrimConcat->name(), - prim::kPrimTupleToArray->name(), - prim::kPrimPack->name(), - prim::kPrimSlice->name(), - prim::kPrimStridedSlice->name(), - prim::kPrimScatterNd->name(), - prim::kPrimReshape->name(), - prim::kPrimTile->name(), - prim::kPrimConv3DBackpropFilter->name(), - prim::kPrimCentralization->name(), - prim::kPrimMerge->name(), - prim::kPrimCustom->name(), - prim::kPrimAssert->name(), - prim::kPrimReduceMean->name(), - prim::kPrimReduceSum->name(), - prim::kPrimReduceAll->name(), - prim::kPrimReduceAny->name(), - prim::kPrimReduceMax->name(), - prim::kPrimReduceMin->name(), - prim::kPrimLstm->name(), - prim::kPrimConv3DBackpropInput->name(), - prim::kPrimCheckBprop->name(), - prim::kPrimPush->name(), - prim::kPrimIdentity->name(), - prim::kPrimPyFunc->name(), - prim::kPrimStandardNormal->name(), - prim::kPrimUniformReal->name(), - prim::kPrimSparseToDense->name(), - prim::kPrimSparseTensorDenseMatmul->name(), - prim::kPrimBroadcast->name(), - prim::kPrimEinsumGrad->name(), - prim::kPrimFlattenGrad->name(), - prim::kPrimMirror->name(), - prim::kPrimMirrorMiniStep->name(), - prim::kPrimMiniStepAllGather->name(), - prim::kPrimMicroStepAllGather->name(), - prim::kPrimVirtualDiv->name(), - prim::kPrimVirtualAdd->name(), - prim::kPrimVirtualDataset->name(), - prim::kPrimVirtualOutput->name(), - prim::kPrimFill->name(), - "InvertPermutation", - "Meshgrid", - "TransShape", - "ParallelConcat", - "CudnnGRU", - "GetModel", - "UpdateModel", - "ReduceProd", - "StandardLaplace", - "Gamma", - "Poisson", - "UniformInt", - "BufferSample", - "BufferAppend", - "BufferGetItem", -}; +// The Python primitives who visit tuple/list elements, but not consume all elements. +// Including: +// - Consume no element. For instance, MakeTuple. +// - Consume partial elements, not all. For instance, TupleGetItem. +// Map{"primitive name", {vector:"index to transparent pass, -1 means all elements"}} +mindspore::HashMap> prims_transparent_pass_sequence{ + {prim::kPrimReturn->name(), std::vector({0})}, {prim::kPrimDepend->name(), std::vector({0})}, + {prim::kPrimIdentity->name(), std::vector({0})}, {prim::kPrimMakeTuple->name(), std::vector({-1})}, + {prim::kPrimMakeList->name(), std::vector({-1})}, {prim::kPrimListAppend->name(), std::vector({0})}, + {prim::kPrimTupleGetItem->name(), std::vector({0})}, {prim::kPrimListGetItem->name(), std::vector({0})}}; EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &out_conf) { @@ -689,13 +639,14 @@ void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBa AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) { // Convert to AbstractValue based on type and shape - auto out_dtype = output[ATTR_DTYPE]; if (output[ATTR_VALUE].is_none()) { return MakePyInferRes2Abstract(output); } + // Convert pyobject to Value, then to AbstractValue - ValuePtr converted_ret = nullptr; + auto out_dtype = output[ATTR_DTYPE]; TypePtr dtype = py::isinstance(out_dtype) ? out_dtype.cast() : nullptr; + ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype); if (!converted) { MS_LOG(EXCEPTION) << "Convert data failed"; @@ -804,7 +755,61 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en return eval_result; } +namespace { +void CheckSequenceArgumentForCppPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) { + // To check tuple/list operations with a white list of Python primitive. + auto iter = prims_transparent_pass_sequence.find(prim->name()); + if (iter == prims_transparent_pass_sequence.end()) { + // The primitive use all elements of each argument. + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->isa()) { + MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i + << "]: " << args[i]->ToString(); + SetSequenceElementsUseFlags(args[i], true); + } + } + return; + } + + // It's transparent pass primitive or using partial elements primitive. + auto index_list = iter->second; + if (index_list.empty()) { + MS_LOG(EXCEPTION) << "The primitive list should not be empty for " << prim->name(); + } + // Ignore all arguments, no need checking if AbstractSequence. + if (index_list[0] == -1) { + return; + } + // Check the specific arguments index. + for (size_t i = 0; i < args.size(); ++i) { + if (!args[i]->isa()) { + continue; + } + if (std::find(index_list.begin(), index_list.end(), i) == index_list.end()) { + // For current tuple/list argument, it's not a primitive of total transparent pass or partial element use. + MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming specific tuple/list arguments[" << i + << "]: " << args[i]->ToString(); + SetSequenceElementsUseFlags(args[i], true); + } + } +} + +void CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) { + // Consider all primitive implemented python infer() real use the tuple/list arguments. + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->isa()) { + MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i + << "]: " << args[i]->ToString(); + SetSequenceElementsUseFlags(args[i], true); + } + } +} +} // namespace + EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + // To check tuple/list operations with a white list of Python primitive. + CheckSequenceArgumentForCppPrimitive(prim_, args); + if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { auto ret_abstract = AbstractEval(args); if (ret_abstract != nullptr) { @@ -846,6 +851,9 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c } EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + // Consider all primitive implemented python infer() real use the tuple/list arguments. + CheckSequenceArgumentForPythonPrimitive(prim_py_, args); + // Ensure input arguments are evaluated. auto ret_abstract = AbstractEval(args); if (ret_abstract != nullptr) { @@ -875,9 +883,9 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con prim_py_->EndRecordAddAttr(); const auto &added_attrs = prim_py_->evaluate_added_attrs(); MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); - auto res_spec = PyInferRes2Abstract(prim_py_, output); - MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; - eval_result = std::make_shared(res_spec, std::make_shared(added_attrs)); + auto res_abs = PyInferRes2Abstract(prim_py_, output); + MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString(); + eval_result = std::make_shared(res_abs, std::make_shared(added_attrs)); // Save result to global primitive eval cache. if (enable_global_cache) { eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result); @@ -886,13 +894,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, con // Save result to evaluator cache. evaluator_cache_mgr_->SetValue(args, eval_result); } - // To check tuple/list operations with a white list of Python primitive. - if (prims_use_sequence_elements.find(prim_py_->name()) != prims_use_sequence_elements.end()) { - // Set all used flags of tuple as true. - for (auto &arg : args) { - SetSequenceElementsUseFlags(arg, true); - } - } return eval_result; } diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 93ab35609b3..8360b8b295e 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -303,7 +303,6 @@ AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const Abst // Inputs: a tuple or list or dict. CheckArgsSize(op_name, args_spec_list, 1); auto arg = CheckArg(op_name, args_spec_list, 0); - SetSequenceElementsUseFlags(arg, true); return std::make_shared(SizeToLong(arg->size())); } } // namespace abstract diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index f8f6e669175..370b8f34570 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -115,8 +115,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int64_t n) -> AbstractBasePtr { return std::make_shared(std::make_shared(n), kInt64); }); - SetSequenceElementsUseFlags(xs, true); - SetSequenceElementsUseFlags(ys, true); return std::make_shared(elems); } @@ -137,8 +135,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr arg = CheckArg(op_name, args_spec_list, 0); tuple_len = arg->elements().size(); tensor_base = CheckArg(op_name, arg->elements(), 0); - // For Stack(tuple), set all used flags of tuple as true. - SetSequenceElementsUseFlags(args_spec_list[0], true); } else if (args_spec_list[0]->isa()) { tuple_len = args_spec_list.size(); tensor_base = CheckArg(op_name, args_spec_list, 0); @@ -1369,8 +1365,6 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv min_shape[0] = 1; max_shape[0] = indices_total_size * EXPAND_MAX; } - SetSequenceElementsUseFlags(input_tuple, true); - SetSequenceElementsUseFlags(input_tuple_1, true); return std::make_shared(infer_type, std::make_shared(out_shape, min_shape, max_shape)); } @@ -1381,9 +1375,6 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi constexpr auto kTensorCopySlicesInputNum = 5; CheckArgsSize(op_name, args_spec_list, kTensorCopySlicesInputNum); AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); - for (size_t i = 2; i < args_spec_list.size(); ++i) { - SetSequenceElementsUseFlags(args_spec_list[i], true); - } return std::make_shared(input->element(), input->shape()); } } // namespace abstract diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 29f05fad0c2..466d73f5fb2 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -191,9 +191,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p return args_spec_list[0]; } - // For F.depend(x, MakeTuple()) or F.depend(x, tuple), set all used flags of tuple as true. - SetSequenceElementsUseFlags(dependant_abstract, true); - auto depends = args_spec_list[0]->Broaden(); // Avoid eliminating the dependent node. if (!MsContext::GetInstance()->get_param(MS_CTX_GRAD_FOR_SCALAR)) { // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value. @@ -210,12 +207,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0"; } MS_EXCEPTION_IF_NULL(args_spec_list[0]); - - // For UpdateState(x, MakeTuple()) or UpdateState(x, tuple), set all used flags of tuple as true. - for (size_t i = 1; i < args_spec_list.size(); i++) { - SetSequenceElementsUseFlags(args_spec_list[i], true); - } - return args_spec_list[0]->Broaden(); } @@ -279,7 +270,6 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv ret->set_indices(indices); ret->set_values(values); ret->set_dense_shape(dense_shape); - SetSequenceElementsUseFlags(dense_shape, true); return ret; } @@ -383,7 +373,6 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi ret->set_indices(indices); ret->set_values(values); ret->set_dense_shape(dense_shape); - SetSequenceElementsUseFlags(dense_shape, true); return ret; } @@ -597,7 +586,6 @@ AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const Primitiv ret->set_indices(indices); ret->set_values(values); ret->set_dense_shape(shape); - SetSequenceElementsUseFlags(shape, true); return ret; } diff --git a/mindspore/core/abstract/prim_statement.cc b/mindspore/core/abstract/prim_statement.cc index a1d15ea0998..7ff9ae03ca9 100644 --- a/mindspore/core/abstract/prim_statement.cc +++ b/mindspore/core/abstract/prim_statement.cc @@ -99,9 +99,6 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP } } - // For F.SwitchLayer(MakeTuple(), x) or F.depend(tuple, x), set all used flags of tuple as true. - SetSequenceElementsUseFlags(branches_abs, true); - auto b = branches[0]; // Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0] // which will cancel the out of bound checking for index diff --git a/mindspore/core/abstract/prim_structures.cc b/mindspore/core/abstract/prim_structures.cc index 6d343072c9f..6c84485ed58 100644 --- a/mindspore/core/abstract/prim_structures.cc +++ b/mindspore/core/abstract/prim_structures.cc @@ -60,8 +60,6 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr auto key_string = GetValue(keyPtr); (void)key_value.emplace_back(key_string, value_list[index]); } - SetSequenceElementsUseFlags(keys, true); - SetSequenceElementsUseFlags(values, true); return std::make_shared(key_value); } @@ -167,7 +165,6 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra constexpr int target_value_index = 2; elements[index_unsigned_value] = args_spec_list[target_value_index]; MS_LOG(DEBUG) << "SetItem use flags, index: " << index_unsigned_value << ", for " << queue->ToString(); - SetSequenceElementsUseFlags(queue, index_unsigned_value, true); return std::make_shared(elements, queue->sequence_nodes()); } diff --git a/mindspore/core/ops/accumulate_n_v2.cc b/mindspore/core/ops/accumulate_n_v2.cc index 620d9f7af40..77ec8f2a7f5 100644 --- a/mindspore/core/ops/accumulate_n_v2.cc +++ b/mindspore/core/ops/accumulate_n_v2.cc @@ -96,7 +96,6 @@ AbstractBasePtr AccumulateNV2Infer(const abstract::AnalysisEnginePtr &, const Pr for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - SetSequenceElementsUseFlags(input_args[0], true); auto infer_type = AccumulateNV2InferType(primitive, input_args); auto infer_shape = AccumulateNV2InferShape(primitive, input_args); return abstract::MakeAbstract(infer_shape, infer_type); diff --git a/mindspore/core/ops/addn.cc b/mindspore/core/ops/addn.cc index b44e665a09a..d1e11868890 100644 --- a/mindspore/core/ops/addn.cc +++ b/mindspore/core/ops/addn.cc @@ -98,10 +98,7 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto res = abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args)); - // For AddN(MakeTuple()) or AddN(tuple), set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[0], true); - return res; + return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/ctc_loss_v2.cc b/mindspore/core/ops/ctc_loss_v2.cc index 32612d587a9..89985a66403 100644 --- a/mindspore/core/ops/ctc_loss_v2.cc +++ b/mindspore/core/ops/ctc_loss_v2.cc @@ -39,7 +39,6 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); - SetSequenceElementsUseFlags(item, true); } auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape()); diff --git a/mindspore/core/ops/ctc_loss_v2_grad.cc b/mindspore/core/ops/ctc_loss_v2_grad.cc index 91364f17341..66f58d6ad19 100644 --- a/mindspore/core/ops/ctc_loss_v2_grad.cc +++ b/mindspore/core/ops/ctc_loss_v2_grad.cc @@ -37,7 +37,6 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); - SetSequenceElementsUseFlags(item, true); } auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape()); auto log_probs_shape = log_probs_shape_map[kShape]; diff --git a/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc b/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc index c2b5790cc90..71bbd31d962 100644 --- a/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc +++ b/mindspore/core/ops/dynamic_resize_nearest_neighbor.cc @@ -118,14 +118,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & } // namespace AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - constexpr size_t size_index = 1; auto prim_name = primitive->name(); const int64_t input_num = 2; (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)), kEqual, input_num, prim_name); auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[size_index], true); return res; } REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor, diff --git a/mindspore/core/ops/einsum.cc b/mindspore/core/ops/einsum.cc index 011e0b8d95b..c1cc4e6e9cd 100644 --- a/mindspore/core/ops/einsum.cc +++ b/mindspore/core/ops/einsum.cc @@ -240,8 +240,6 @@ AbstractBasePtr EinsumInfer(const abstract::AnalysisEnginePtr &, const Primitive CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); auto res = std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)); - // For Einsum(tuple/list), set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[0], true); return res; } REGISTER_PRIMITIVE_EVAL_IMPL(Einsum, prim::kPrimEinsum, EinsumInfer, nullptr, true); diff --git a/mindspore/core/ops/grad/avg_pool_3d_grad.cc b/mindspore/core/ops/grad/avg_pool_3d_grad.cc index 82c79e00ed4..3f65b95765c 100644 --- a/mindspore/core/ops/grad/avg_pool_3d_grad.cc +++ b/mindspore/core/ops/grad/avg_pool_3d_grad.cc @@ -60,11 +60,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) { - constexpr size_t origin_input_size_index = 0; auto res = std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)->shape()); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[origin_input_size_index], true); return res; } diff --git a/mindspore/core/ops/grad/conv2d_backprop_filter.cc b/mindspore/core/ops/grad/conv2d_backprop_filter.cc index 8940076eb1a..c956f7aa330 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_filter.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_filter.cc @@ -241,8 +241,6 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c } auto res = std::make_shared(Conv2DBackpropFilterInferType(primitive, input_args), Conv2DBackpropFilterInferShape(primitive, input_args)); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[kFilterSizeIndex], true); return res; } REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr, diff --git a/mindspore/core/ops/grad/dropout_grad.cc b/mindspore/core/ops/grad/dropout_grad.cc index 789f18127a7..f21f6acfa2e 100644 --- a/mindspore/core/ops/grad/dropout_grad.cc +++ b/mindspore/core/ops/grad/dropout_grad.cc @@ -40,7 +40,6 @@ float DropoutGrad::get_keep_prob() const { AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - constexpr size_t shape_index = 0; auto op_name = primitive->name(); const int64_t input_num = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name); @@ -55,8 +54,6 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim auto out_type = CheckAndConvertUtils::CheckTensorTypeValid("x", dy_type, valid_types, op_name); auto shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, dy_index); auto res = abstract::MakeAbstract(shape, out_type); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[shape_index], true); return res; } REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGrad, prim::kPrimDropoutGrad, DropoutGradInfer, nullptr, true); diff --git a/mindspore/core/ops/grad/strided_slice_grad.cc b/mindspore/core/ops/grad/strided_slice_grad.cc index aa7fd2e6f75..c89f7d609ec 100644 --- a/mindspore/core/ops/grad/strided_slice_grad.cc +++ b/mindspore/core/ops/grad/strided_slice_grad.cc @@ -118,19 +118,10 @@ TypePtr StridedSliceGradInferType(const PrimitivePtr &primitive, const std::vect AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - constexpr size_t shape_index = 1; - constexpr size_t begin_index = 2; - constexpr size_t end_index = 3; - constexpr size_t stride_index = 4; constexpr int64_t input_num = 5; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); auto res = abstract::MakeAbstract(StridedSliceGradInferShape(primitive, input_args), StridedSliceGradInferType(primitive, input_args)); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[shape_index], true); - SetSequenceElementsUseFlags(input_args[begin_index], true); - SetSequenceElementsUseFlags(input_args[end_index], true); - SetSequenceElementsUseFlags(input_args[stride_index], true); return res; } diff --git a/mindspore/core/ops/lp_norm.cc b/mindspore/core/ops/lp_norm.cc index d54745c7cdd..20a93f8ceef 100644 --- a/mindspore/core/ops/lp_norm.cc +++ b/mindspore/core/ops/lp_norm.cc @@ -105,7 +105,6 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name()); auto infer_type = InferType(primitive, input_args); auto infer_shape = InferShape(primitive, input_args); - SetSequenceElementsUseFlags(input_args[0], true); return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true); diff --git a/mindspore/core/ops/neighborexchange.cc b/mindspore/core/ops/neighborexchange.cc index 7e8154cdd83..edae7a7c9bb 100644 --- a/mindspore/core/ops/neighborexchange.cc +++ b/mindspore/core/ops/neighborexchange.cc @@ -182,16 +182,10 @@ TypePtr InferType(const PrimitivePtr &primitive) { } // namespace AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - constexpr size_t input_tuple_index = 0; Check(primitive, input_args); auto type = InferType(primitive); auto shape = InferShape(primitive); - auto res = abstract::MakeAbstract(shape, type); - // Set all used flags of tuple as true. - if (!input_args.empty()) { - SetSequenceElementsUseFlags(input_args[input_tuple_index], true); - } - return res; + return abstract::MakeAbstract(shape, type); } REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/ones.cc b/mindspore/core/ops/ones.cc index 5db00cf9fc2..a40183331fb 100644 --- a/mindspore/core/ops/ones.cc +++ b/mindspore/core/ops/ones.cc @@ -57,14 +57,10 @@ TypePtr OnesInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - constexpr size_t shape_index = 0; const std::string op_name = primitive->name(); const int64_t input_num = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name); - auto res = abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args)); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[shape_index], true); - return res; + return abstract::MakeAbstract(OnesInferShape(primitive, input_args), OnesInferType(primitive, input_args)); } ValuePtr OnesInferValue(const PrimitivePtr &prim, const std::vector &input_args) { diff --git a/mindspore/core/ops/transpose.cc b/mindspore/core/ops/transpose.cc index 7d4cb7e1abd..4e852918bd3 100644 --- a/mindspore/core/ops/transpose.cc +++ b/mindspore/core/ops/transpose.cc @@ -93,20 +93,13 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - constexpr size_t input_perm_index = 1; // The second input is optional. constexpr size_t input_size1 = 1; - constexpr size_t input_size2 = 2; (void)CheckAndConvertUtils::CheckInteger("Transpose infer", SizeToLong(input_args.size()), kGreaterEqual, input_size1, primitive->name()); auto type = InferType(primitive, input_args); auto shape = InferShape(primitive, input_args); - auto res = abstract::MakeAbstract(shape, type); - if (input_args.size() == input_size2) { - SetSequenceElementsUseFlags(input_args[input_perm_index], true); - } - - return res; + return abstract::MakeAbstract(shape, type); } REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true); } // namespace ops diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc index 09e50459f49..9382c781907 100644 --- a/mindspore/core/ops/zeros.cc +++ b/mindspore/core/ops/zeros.cc @@ -59,11 +59,7 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - constexpr size_t shape_index = 0; - auto res = abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args)); - // Set all used flags of tuple as true. - SetSequenceElementsUseFlags(input_args[shape_index], true); - return res; + return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args)); } ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector &input_args) {