fixed: PrimitiveToInferImplMap map is global, and key of the map PrimitivePtr also a global variable. If key is initialized later than the map initialized during compilation, will cause the primitive map initialize failed. Variable initialization order is not guaranteed during compilation.

This commit is contained in:
mxm 2020-03-31 12:08:02 +08:00 committed by chang zherui
parent 6d1ea7af8e
commit d375724055
2 changed files with 148 additions and 149 deletions

View File

@ -42,92 +42,95 @@
namespace mindspore {
namespace abstract {
PrimitiveEvalImplMap PrimitiveToInferImplMap = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimTypeOf, {InferImplTypeof, false}},
{prim::kPrimHasType, {InferImplHasType, false}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}},
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimShape, {InferImplShape, true}},
{prim::kPrimPack, {InferImplPack, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
{prim::kPrimMakeRecord, {InferImplMakeRecord, false}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
{prim::kPrimListMap, {InferImplListMap, false}},
{prim::kPrimListReduce, {InferImplListReduce, false}},
{prim::kPrimTupleReversed, {InferImplTupleReversed, false}},
{prim::kPrimReducedShape, {InferImplReduceShape, false}},
{prim::kPrimTupleDiv, {InferImplTupleDiv, false}},
{prim::kPrimTupleToArray, {InferImplTuple2Array, false}},
{prim::kPrimShapeMul, {InferImplShapeMul, false}},
{prim::kPrimTupleEqual, {InferImplTupleEqual, false}},
{prim::kPrimListEqual, {InferImplListEqual, false}},
{prim::kPrimMakeRange, {InferImplMakeRange, false}},
{prim::kPrimStopGradient, {InferImplStopGradient, false}},
{prim::kPrimStringEqual, {InferImplStringEqual, false}},
{prim::kPrimDictLen, {InferImplDictLen, false}},
// NN
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
// Others
{prim::kPrimIdentity, {InferImplIdentity, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, {nullptr, true}},
{prim::kPrimJ, {InferImplJ, false}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
// Debug
{prim::kPrimScalarSummary, {InferImplScalarSummary, true}},
{prim::kPrimImageSummary, {InferImplTensorSummary, true}},
{prim::kPrimTensorSummary, {InferImplTensorSummary, true}},
};
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimTypeOf, {InferImplTypeof, false}},
{prim::kPrimHasType, {InferImplHasType, false}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}},
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimShape, {InferImplShape, true}},
{prim::kPrimPack, {InferImplPack, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
{prim::kPrimMakeRecord, {InferImplMakeRecord, false}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
{prim::kPrimListMap, {InferImplListMap, false}},
{prim::kPrimListReduce, {InferImplListReduce, false}},
{prim::kPrimTupleReversed, {InferImplTupleReversed, false}},
{prim::kPrimReducedShape, {InferImplReduceShape, false}},
{prim::kPrimTupleDiv, {InferImplTupleDiv, false}},
{prim::kPrimTupleToArray, {InferImplTuple2Array, false}},
{prim::kPrimShapeMul, {InferImplShapeMul, false}},
{prim::kPrimTupleEqual, {InferImplTupleEqual, false}},
{prim::kPrimListEqual, {InferImplListEqual, false}},
{prim::kPrimMakeRange, {InferImplMakeRange, false}},
{prim::kPrimStopGradient, {InferImplStopGradient, false}},
{prim::kPrimStringEqual, {InferImplStringEqual, false}},
{prim::kPrimDictLen, {InferImplDictLen, false}},
// NN
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
// Others
{prim::kPrimIdentity, {InferImplIdentity, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, {nullptr, true}},
{prim::kPrimJ, {InferImplJ, false}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
// Debug
{prim::kPrimScalarSummary, {InferImplScalarSummary, true}},
{prim::kPrimImageSummary, {InferImplTensorSummary, true}},
{prim::kPrimTensorSummary, {InferImplTensorSummary, true}},
};
return prim_eval_implement_map;
}
using mindspore::parse::PyObjectWrapper;
@ -961,10 +964,7 @@ class PartialEvaluator : public Evaluator {
new_nodes_inputs[1] = NewValueNode(new_signature_value);
FuncGraphPtr func_graph = cnode->func_graph();
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
}
ScopePtr scope = out_conf->node()->scope();
ScopeGuard scope_guard(scope);
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
@ -981,39 +981,41 @@ struct PrimitiveImplInferValue {
};
using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
PrimitiveToImplMap UniformPrimitiveToImplMapValue = {
{prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
{prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
{prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
{prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
{prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
{prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
{prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
{prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
{prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
};
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
static PrimitiveToImplMap uniform_prim_implement_map = {
{prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
{prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
{prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
{prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
{prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
{prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
{prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
{prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
{prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
{prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
};
return uniform_prim_implement_map;
}
PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
std::mutex PrimEvaluatorConstructorMutex;
void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_map) {
void InitPrimEvaluatorConstructors() {
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
for (const auto &iter : prim_eval_impl_map) {
for (const auto &iter : GetPrimitiveToEvalImplMap()) {
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
}
for (const auto &iter : UniformPrimitiveToImplMapValue) {
for (const auto &iter : GetUniformPrimitiveToImplMap()) {
constructor[iter.first] =
InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
}
@ -1028,20 +1030,20 @@ void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_ma
void ClearPrimEvaluatorMap() {
PrimEvaluatorConstructors.clear();
PrimitiveToInferImplMap.clear();
UniformPrimitiveToImplMapValue.clear();
GetPrimitiveToEvalImplMap().clear();
GetUniformPrimitiveToImplMap().clear();
}
bool IsInWhiteList(const PrimitivePtr primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = PrimitiveToInferImplMap.find(primitive);
if (iter != PrimitiveToInferImplMap.end()) {
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter != GetPrimitiveToEvalImplMap().end()) {
return iter->second.in_white_list_;
}
auto uni_iter = UniformPrimitiveToImplMapValue.find(primitive);
if (uni_iter != UniformPrimitiveToImplMapValue.end()) {
auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
return uni_iter->second.in_white_list_;
}
@ -1050,8 +1052,8 @@ bool IsInWhiteList(const PrimitivePtr primitive) {
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = PrimitiveToInferImplMap.find(primitive);
if (iter == PrimitiveToInferImplMap.end()) {
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == GetPrimitiveToEvalImplMap().end()) {
return nullptr;
}
return iter->second.impl_;
@ -1064,7 +1066,7 @@ PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
}
std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
if (constructor.empty()) {
InitPrimEvaluatorConstructors(PrimitiveToInferImplMap);
InitPrimEvaluatorConstructors();
}
return constructor;

View File

@ -296,38 +296,35 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) {
evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
} else {
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
return std::make_shared<PythonPrimEvaluator>(prim_py);
}
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
}
if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
if (engine == nullptr) {
(void)GetPrimEvaluatorConstructors();
}
} else if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
// If a primitive may have attr, try to create a new evaluator.
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl != nullptr) {
std::shared_ptr<StandardPrimEvaluator> standard_evaluator =
std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
evaluator = standard_evaluator;
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}
}
if (evaluator == nullptr) {
if (engine == nullptr) {
// If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
auto iter = prim_evaluator_map.find(prim);
if (iter == prim_evaluator_map.end()) {
evaluator = nullptr;
} else {
evaluator = iter->second;
}
} else {
// If engine is given, get constructor from engine resource.
const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
auto iter = prim_evaluator_map.find(prim);
if (iter == prim_evaluator_map.end()) {
evaluator = nullptr;
} else {
evaluator = iter->second;
}
if (engine == nullptr) {
// If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
auto iter = prim_evaluator_map.find(prim);
if (iter != prim_evaluator_map.end()) {
evaluator = iter->second;
}
} else {
// If engine is given, get constructor from engine resource.
const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
auto iter = prim_evaluator_map.find(prim);
if (iter != prim_evaluator_map.end()) {
evaluator = iter->second;
}
}
if (evaluator == nullptr) {