!65 fix bug of global variable initialization order is not guaranteed during compilation

Merge pull request !65 from mxm/for_prim_map_init
This commit is contained in:
mindspore-ci-bot 2020-04-09 10:19:50 +08:00 committed by Gitee
commit 086edc133f
2 changed files with 148 additions and 149 deletions

View File

@ -42,92 +42,95 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
PrimitiveEvalImplMap PrimitiveToInferImplMap = { PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
// Statements static PrimitiveEvalImplMap prim_eval_implement_map = {
{prim::kPrimReturn, {InferImplReturn, true}}, // Statements
{prim::kPrimTypeOf, {InferImplTypeof, false}}, {prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimHasType, {InferImplHasType, false}}, {prim::kPrimTypeOf, {InferImplTypeof, false}},
{prim::kPrimDot, {InferImplDot, true}}, {prim::kPrimHasType, {InferImplHasType, false}},
{prim::kPrimSwitch, {InferImplSwitch, true}}, {prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimIs_, {InferImplIs_, true}}, {prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}}, {prim::kPrimIs_, {InferImplIs_, true}},
// Maths {prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, // Maths
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
// Array {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, // Array
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimShape, {InferImplShape, true}}, {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimPack, {InferImplPack, true}}, {prim::kPrimShape, {InferImplShape, true}},
// Structure {prim::kPrimPack, {InferImplPack, true}},
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, // Structure
{prim::kPrimMakeList, {InferImplMakeList, true}}, {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, {prim::kPrimMakeDict, {InferImplMakeDict, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, {prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
{prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, {prim::kPrimMakeRecord, {InferImplMakeRecord, false}},
{prim::kPrimListGetItem, {InferImplListGetItem, true}}, {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, {prim::kPrimListGetItem, {InferImplListGetItem, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, true}}, {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, {prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, {prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}}, {prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}}, {prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimListLen, {InferImplListLen, true}}, {prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, true}}, {prim::kPrimListLen, {InferImplListLen, true}},
{prim::kPrimListMap, {InferImplListMap, false}}, {prim::kPrimArrayLen, {InferImplArrayLen, true}},
{prim::kPrimListReduce, {InferImplListReduce, false}}, {prim::kPrimListMap, {InferImplListMap, false}},
{prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, {prim::kPrimListReduce, {InferImplListReduce, false}},
{prim::kPrimReducedShape, {InferImplReduceShape, false}}, {prim::kPrimTupleReversed, {InferImplTupleReversed, false}},
{prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, {prim::kPrimReducedShape, {InferImplReduceShape, false}},
{prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, {prim::kPrimTupleDiv, {InferImplTupleDiv, false}},
{prim::kPrimShapeMul, {InferImplShapeMul, false}}, {prim::kPrimTupleToArray, {InferImplTuple2Array, false}},
{prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, {prim::kPrimShapeMul, {InferImplShapeMul, false}},
{prim::kPrimListEqual, {InferImplListEqual, false}}, {prim::kPrimTupleEqual, {InferImplTupleEqual, false}},
{prim::kPrimMakeRange, {InferImplMakeRange, false}}, {prim::kPrimListEqual, {InferImplListEqual, false}},
{prim::kPrimStopGradient, {InferImplStopGradient, false}}, {prim::kPrimMakeRange, {InferImplMakeRange, false}},
{prim::kPrimStringEqual, {InferImplStringEqual, false}}, {prim::kPrimStopGradient, {InferImplStopGradient, false}},
{prim::kPrimDictLen, {InferImplDictLen, false}}, {prim::kPrimStringEqual, {InferImplStringEqual, false}},
// NN {prim::kPrimDictLen, {InferImplDictLen, false}},
{prim::kPrimPooling, {InferImplPooling, true}}, // NN
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, {prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}}, {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, {prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimRelu, {InferImplRelu, true}}, {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}}, {prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
// Others {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
{prim::kPrimIdentity, {InferImplIdentity, true}}, // Others
// Set impl to null as it will use PartialEvaluator; {prim::kPrimIdentity, {InferImplIdentity, true}},
{prim::kPrimPartial, {nullptr, true}}, // Set impl to null as it will use PartialEvaluator;
{prim::kPrimJ, {InferImplJ, false}}, {prim::kPrimPartial, {nullptr, true}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, {prim::kPrimJ, {InferImplJ, false}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, {prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, true}}, {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, {prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, {prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, {prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
{prim::kPrimDepend, {InferImplDepend, true}}, {prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, {prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimControlDepend, {InferImplControlDepend, true}}, {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
// Debug {prim::kPrimControlDepend, {InferImplControlDepend, true}},
{prim::kPrimScalarSummary, {InferImplScalarSummary, true}}, // Debug
{prim::kPrimImageSummary, {InferImplTensorSummary, true}}, {prim::kPrimScalarSummary, {InferImplScalarSummary, true}},
{prim::kPrimTensorSummary, {InferImplTensorSummary, true}}, {prim::kPrimImageSummary, {InferImplTensorSummary, true}},
}; {prim::kPrimTensorSummary, {InferImplTensorSummary, true}},
};
return prim_eval_implement_map;
}
using mindspore::parse::PyObjectWrapper; using mindspore::parse::PyObjectWrapper;
@ -961,10 +964,7 @@ class PartialEvaluator : public Evaluator {
new_nodes_inputs[1] = NewValueNode(new_signature_value); new_nodes_inputs[1] = NewValueNode(new_signature_value);
FuncGraphPtr func_graph = cnode->func_graph(); FuncGraphPtr func_graph = cnode->func_graph();
ScopePtr scope = kDefaultScope; ScopePtr scope = out_conf->node()->scope();
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
}
ScopeGuard scope_guard(scope); ScopeGuard scope_guard(scope);
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs); CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
@ -981,39 +981,41 @@ struct PrimitiveImplInferValue {
}; };
using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>; using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
PrimitiveToImplMap UniformPrimitiveToImplMapValue = { static PrimitiveToImplMap uniform_prim_implement_map = {
{prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}}, {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
{prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}}, {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
{prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
{prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
{prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
{prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
{prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
{prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
{prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}}, {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}}, {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}}, {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}}, {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}}, {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}}, {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}}, {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}}, {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}}, {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}}, {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
}; };
return uniform_prim_implement_map;
}
PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap(); PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
std::mutex PrimEvaluatorConstructorMutex; std::mutex PrimEvaluatorConstructorMutex;
void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_map) { void InitPrimEvaluatorConstructors() {
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
for (const auto &iter : prim_eval_impl_map) { for (const auto &iter : GetPrimitiveToEvalImplMap()) {
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
} }
for (const auto &iter : UniformPrimitiveToImplMapValue) { for (const auto &iter : GetUniformPrimitiveToImplMap()) {
constructor[iter.first] = constructor[iter.first] =
InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_); InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
} }
@ -1028,20 +1030,20 @@ void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_ma
void ClearPrimEvaluatorMap() { void ClearPrimEvaluatorMap() {
PrimEvaluatorConstructors.clear(); PrimEvaluatorConstructors.clear();
PrimitiveToInferImplMap.clear(); GetPrimitiveToEvalImplMap().clear();
UniformPrimitiveToImplMapValue.clear(); GetUniformPrimitiveToImplMap().clear();
} }
bool IsInWhiteList(const PrimitivePtr primitive) { bool IsInWhiteList(const PrimitivePtr primitive) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto iter = PrimitiveToInferImplMap.find(primitive); auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter != PrimitiveToInferImplMap.end()) { if (iter != GetPrimitiveToEvalImplMap().end()) {
return iter->second.in_white_list_; return iter->second.in_white_list_;
} }
auto uni_iter = UniformPrimitiveToImplMapValue.find(primitive); auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
if (uni_iter != UniformPrimitiveToImplMapValue.end()) { if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
return uni_iter->second.in_white_list_; return uni_iter->second.in_white_list_;
} }
@ -1050,8 +1052,8 @@ bool IsInWhiteList(const PrimitivePtr primitive) {
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto iter = PrimitiveToInferImplMap.find(primitive); auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == PrimitiveToInferImplMap.end()) { if (iter == GetPrimitiveToEvalImplMap().end()) {
return nullptr; return nullptr;
} }
return iter->second.impl_; return iter->second.impl_;
@ -1064,7 +1066,7 @@ PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
} }
std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex); std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
if (constructor.empty()) { if (constructor.empty()) {
InitPrimEvaluatorConstructors(PrimitiveToInferImplMap); InitPrimEvaluatorConstructors();
} }
return constructor; return constructor;

View File

@ -296,38 +296,35 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
if (prim->HasPyEvaluator()) { if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim); auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) { if (prim_py != nullptr) {
evaluator = std::make_shared<PythonPrimEvaluator>(prim_py); return std::make_shared<PythonPrimEvaluator>(prim_py);
} else { }
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
}
if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
if (engine == nullptr) {
(void)GetPrimEvaluatorConstructors();
} }
} else if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
// If a primitive may have attr, try to create a new evaluator. // If a primitive may have attr, try to create a new evaluator.
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl != nullptr) { if (eval_impl != nullptr) {
std::shared_ptr<StandardPrimEvaluator> standard_evaluator = return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
evaluator = standard_evaluator;
} }
} }
if (evaluator == nullptr) {
if (engine == nullptr) { if (engine == nullptr) {
// If engine is nullptr, get constructor from default. // If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
auto iter = prim_evaluator_map.find(prim); auto iter = prim_evaluator_map.find(prim);
if (iter == prim_evaluator_map.end()) { if (iter != prim_evaluator_map.end()) {
evaluator = nullptr; evaluator = iter->second;
} else { }
evaluator = iter->second; } else {
} // If engine is given, get constructor from engine resource.
} else { const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
// If engine is given, get constructor from engine resource. auto iter = prim_evaluator_map.find(prim);
const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); if (iter != prim_evaluator_map.end()) {
auto iter = prim_evaluator_map.find(prim); evaluator = iter->second;
if (iter == prim_evaluator_map.end()) {
evaluator = nullptr;
} else {
evaluator = iter->second;
}
} }
} }
if (evaluator == nullptr) { if (evaluator == nullptr) {