From 2f031e154033f3dc82c03fbf084c1ba3ce7bc127 Mon Sep 17 00:00:00 2001 From: mxm <83028974@qq.com> Date: Tue, 31 Mar 2020 12:08:02 +0800 Subject: [PATCH] fixed: PrimitiveToInferImplMap map is global, and key of the map PrimitivePtr also a global variable. If key is initialized later than the map initialized during compilation, will cause the primitive map initialize failed. Variable initialization order is not guaranteed during compilation. --- .../ccsrc/pipeline/static_analysis/prim.cc | 248 +++++++++--------- .../static_analysis/static_analysis.cc | 49 ++-- 2 files changed, 148 insertions(+), 149 deletions(-) diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 98d82de5d59..e06d58466d7 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -42,92 +42,95 @@ namespace mindspore { namespace abstract { -PrimitiveEvalImplMap PrimitiveToInferImplMap = { - // Statements - {prim::kPrimReturn, {InferImplReturn, true}}, - {prim::kPrimTypeOf, {InferImplTypeof, false}}, - {prim::kPrimHasType, {InferImplHasType, false}}, - {prim::kPrimDot, {InferImplDot, true}}, - {prim::kPrimSwitch, {InferImplSwitch, true}}, - {prim::kPrimIs_, {InferImplIs_, true}}, - {prim::kPrimIsNot, {InferImplIsNot, true}}, - // Maths - {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - // Array - {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, - {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, - {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimShape, {InferImplShape, true}}, - {prim::kPrimPack, {InferImplPack, true}}, - // Structure - {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, - {prim::kPrimMakeList, {InferImplMakeList, true}}, - {prim::kPrimMakeDict, {InferImplMakeDict, true}}, - {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, - {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, - {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, - {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, - {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, - {prim::kPrimListGetItem, {InferImplListGetItem, true}}, - {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, - {prim::kPrimListSetItem, {InferImplListSetItem, true}}, - {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, - {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, - {prim::kPrimListAppend, {InferImplListAppend, true}}, - {prim::kPrimTupleLen, {InferImplTupleLen, true}}, - {prim::kPrimListLen, {InferImplListLen, true}}, - {prim::kPrimArrayLen, {InferImplArrayLen, true}}, - {prim::kPrimListMap, {InferImplListMap, false}}, - {prim::kPrimListReduce, {InferImplListReduce, false}}, - {prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, - {prim::kPrimReducedShape, {InferImplReduceShape, false}}, - {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, - {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, - {prim::kPrimShapeMul, {InferImplShapeMul, false}}, - {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, - {prim::kPrimListEqual, {InferImplListEqual, false}}, - {prim::kPrimMakeRange, {InferImplMakeRange, false}}, - {prim::kPrimStopGradient, {InferImplStopGradient, false}}, - {prim::kPrimStringEqual, {InferImplStringEqual, false}}, - {prim::kPrimDictLen, {InferImplDictLen, false}}, - // NN - {prim::kPrimPooling, {InferImplPooling, true}}, - {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, - {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, - {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, - {prim::kPrimReluGrad, {InferImplReluGrad, true}}, - {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, - {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, - {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, - {prim::kPrimRelu, {InferImplRelu, true}}, - {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}}, - {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, - {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, - {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, - {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, - // Others - {prim::kPrimIdentity, {InferImplIdentity, true}}, - // Set impl to null as it will use PartialEvaluator; - {prim::kPrimPartial, {nullptr, true}}, - {prim::kPrimJ, {InferImplJ, false}}, - {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, - {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, - {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, - {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, - {prim::kPrimMakeRef, {InferImplMakeRef, true}}, - {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, - {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, - {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, - {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, - {prim::kPrimDepend, {InferImplDepend, true}}, - {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, - {prim::kPrimControlDepend, {InferImplControlDepend, true}}, - // Debug - {prim::kPrimScalarSummary, {InferImplScalarSummary, true}}, - {prim::kPrimImageSummary, {InferImplTensorSummary, true}}, - {prim::kPrimTensorSummary, {InferImplTensorSummary, true}}, -}; +PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { + static PrimitiveEvalImplMap prim_eval_implement_map = { + // Statements + {prim::kPrimReturn, {InferImplReturn, true}}, + {prim::kPrimTypeOf, {InferImplTypeof, false}}, + {prim::kPrimHasType, {InferImplHasType, false}}, + {prim::kPrimDot, {InferImplDot, true}}, + {prim::kPrimSwitch, {InferImplSwitch, true}}, + {prim::kPrimIs_, {InferImplIs_, true}}, + {prim::kPrimIsNot, {InferImplIsNot, true}}, + // Maths + {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + // Array + {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, + {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, + {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, + {prim::kPrimShape, {InferImplShape, true}}, + {prim::kPrimPack, {InferImplPack, true}}, + // Structure + {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, + {prim::kPrimMakeList, {InferImplMakeList, true}}, + {prim::kPrimMakeDict, {InferImplMakeDict, true}}, + {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, + {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, + {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, + {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, + {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, + {prim::kPrimListGetItem, {InferImplListGetItem, true}}, + {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, + {prim::kPrimListSetItem, {InferImplListSetItem, true}}, + {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, + {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, + {prim::kPrimListAppend, {InferImplListAppend, true}}, + {prim::kPrimTupleLen, {InferImplTupleLen, true}}, + {prim::kPrimListLen, {InferImplListLen, true}}, + {prim::kPrimArrayLen, {InferImplArrayLen, true}}, + {prim::kPrimListMap, {InferImplListMap, false}}, + {prim::kPrimListReduce, {InferImplListReduce, false}}, + {prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, + {prim::kPrimReducedShape, {InferImplReduceShape, false}}, + {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, + {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, + {prim::kPrimShapeMul, {InferImplShapeMul, false}}, + {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, + {prim::kPrimListEqual, {InferImplListEqual, false}}, + {prim::kPrimMakeRange, {InferImplMakeRange, false}}, + {prim::kPrimStopGradient, {InferImplStopGradient, false}}, + {prim::kPrimStringEqual, {InferImplStringEqual, false}}, + {prim::kPrimDictLen, {InferImplDictLen, false}}, + // NN + {prim::kPrimPooling, {InferImplPooling, true}}, + {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, + {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, + {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, + {prim::kPrimReluGrad, {InferImplReluGrad, true}}, + {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, + {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, + {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, + {prim::kPrimRelu, {InferImplRelu, true}}, + {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}}, + {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, + {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, + {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, + {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, + // Others + {prim::kPrimIdentity, {InferImplIdentity, true}}, + // Set impl to null as it will use PartialEvaluator; + {prim::kPrimPartial, {nullptr, true}}, + {prim::kPrimJ, {InferImplJ, false}}, + {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, + {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, + {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, + {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, + {prim::kPrimMakeRef, {InferImplMakeRef, true}}, + {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, + {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, + {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, + {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, + {prim::kPrimDepend, {InferImplDepend, true}}, + {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, + {prim::kPrimControlDepend, {InferImplControlDepend, true}}, + // Debug + {prim::kPrimScalarSummary, {InferImplScalarSummary, true}}, + {prim::kPrimImageSummary, {InferImplTensorSummary, true}}, + {prim::kPrimTensorSummary, {InferImplTensorSummary, true}}, + }; + return prim_eval_implement_map; +} using mindspore::parse::PyObjectWrapper; @@ -907,10 +910,7 @@ class PartialEvaluator : public Evaluator { new_nodes_inputs[1] = NewValueNode(new_signature_value); FuncGraphPtr func_graph = cnode->func_graph(); - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } + ScopePtr scope = out_conf->node()->scope(); ScopeGuard scope_guard(scope); CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs); @@ -927,39 +927,41 @@ struct PrimitiveImplInferValue { }; using PrimitiveToImplMap = std::unordered_map; - -PrimitiveToImplMap UniformPrimitiveToImplMapValue = { - {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}}, - {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}}, - {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, - {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, - {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, - {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, - {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, - {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, - {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared(), true}}, - {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared(), true}}, - {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared(), true}}, - {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared(), true}}, - {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared(), true}}, - {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared(), true}}, - {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared(), true}}, - {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared(), true}}, - {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared(), true}}, - {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared(), true}}, -}; +PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { + static PrimitiveToImplMap uniform_prim_implement_map = { + {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}}, + {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}}, + {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, + {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, + {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, + {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, + {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, + {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, + {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared(), true}}, + {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared(), true}}, + {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared(), true}}, + {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared(), true}}, + {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared(), true}}, + {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared(), true}}, + {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared(), true}}, + {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared(), true}}, + {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared(), true}}, + {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared(), true}}, + }; + return uniform_prim_implement_map; +} PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap(); std::mutex PrimEvaluatorConstructorMutex; -void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_map) { +void InitPrimEvaluatorConstructors() { PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; - for (const auto &iter : prim_eval_impl_map) { + for (const auto &iter : GetPrimitiveToEvalImplMap()) { constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); } - for (const auto &iter : UniformPrimitiveToImplMapValue) { + for (const auto &iter : GetUniformPrimitiveToImplMap()) { constructor[iter.first] = InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_); } @@ -974,20 +976,20 @@ void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_ma void ClearPrimEvaluatorMap() { PrimEvaluatorConstructors.clear(); - PrimitiveToInferImplMap.clear(); - UniformPrimitiveToImplMapValue.clear(); + GetPrimitiveToEvalImplMap().clear(); + GetUniformPrimitiveToImplMap().clear(); } bool IsInWhiteList(const PrimitivePtr primitive) { MS_EXCEPTION_IF_NULL(primitive); - auto iter = PrimitiveToInferImplMap.find(primitive); - if (iter != PrimitiveToInferImplMap.end()) { + auto iter = GetPrimitiveToEvalImplMap().find(primitive); + if (iter != GetPrimitiveToEvalImplMap().end()) { return iter->second.in_white_list_; } - auto uni_iter = UniformPrimitiveToImplMapValue.find(primitive); - if (uni_iter != UniformPrimitiveToImplMapValue.end()) { + auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive); + if (uni_iter != GetUniformPrimitiveToImplMap().end()) { return uni_iter->second.in_white_list_; } @@ -996,8 +998,8 @@ bool IsInWhiteList(const PrimitivePtr primitive) { StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); - auto iter = PrimitiveToInferImplMap.find(primitive); - if (iter == PrimitiveToInferImplMap.end()) { + auto iter = GetPrimitiveToEvalImplMap().find(primitive); + if (iter == GetPrimitiveToEvalImplMap().end()) { return nullptr; } return iter->second.impl_; @@ -1010,7 +1012,7 @@ PrimEvaluatorMap &GetPrimEvaluatorConstructors() { } std::lock_guard initLock(PrimEvaluatorConstructorMutex); if (constructor.empty()) { - InitPrimEvaluatorConstructors(PrimitiveToInferImplMap); + InitPrimEvaluatorConstructors(); } return constructor; diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc index 0bfba265db0..1ac43abdd58 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc @@ -292,38 +292,35 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr if (prim->HasPyEvaluator()) { auto prim_py = dyn_cast(prim); if (prim_py != nullptr) { - evaluator = std::make_shared(prim_py); - } else { - MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; + return std::make_shared(prim_py); + } + MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; + } + + if (prim->isa() || prim->HasAttr()) { + if (engine == nullptr) { + (void)GetPrimEvaluatorConstructors(); } - } else if (prim->isa() || prim->HasAttr()) { // If a primitive may have attr, try to create a new evaluator. StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); if (eval_impl != nullptr) { - std::shared_ptr standard_evaluator = - std::make_shared(prim, eval_impl); - evaluator = standard_evaluator; + return std::make_shared(prim, eval_impl); } } - if (evaluator == nullptr) { - if (engine == nullptr) { - // If engine is nullptr, get constructor from default. - const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); - auto iter = prim_evaluator_map.find(prim); - if (iter == prim_evaluator_map.end()) { - evaluator = nullptr; - } else { - evaluator = iter->second; - } - } else { - // If engine is given, get constructor from engine resource. - const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); - auto iter = prim_evaluator_map.find(prim); - if (iter == prim_evaluator_map.end()) { - evaluator = nullptr; - } else { - evaluator = iter->second; - } + + if (engine == nullptr) { + // If engine is nullptr, get constructor from default. + const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); + auto iter = prim_evaluator_map.find(prim); + if (iter != prim_evaluator_map.end()) { + evaluator = iter->second; + } + } else { + // If engine is given, get constructor from engine resource. + const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); + auto iter = prim_evaluator_map.find(prim); + if (iter != prim_evaluator_map.end()) { + evaluator = iter->second; } } if (evaluator == nullptr) {