!65 fix bug of global variable initialization order is not guaranteed during compilation

Merge pull request !65 from mxm/for_prim_map_init
This commit is contained in:
mindspore-ci-bot 2020-04-09 10:19:50 +08:00 committed by Gitee
commit 086edc133f
2 changed files with 148 additions and 149 deletions

View File

@ -42,7 +42,8 @@
namespace mindspore {
namespace abstract {
PrimitiveEvalImplMap PrimitiveToInferImplMap = {
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimTypeOf, {InferImplTypeof, false}},
@ -127,7 +128,9 @@ PrimitiveEvalImplMap PrimitiveToInferImplMap = {
{prim::kPrimScalarSummary, {InferImplScalarSummary, true}},
{prim::kPrimImageSummary, {InferImplTensorSummary, true}},
{prim::kPrimTensorSummary, {InferImplTensorSummary, true}},
};
};
return prim_eval_implement_map;
}
using mindspore::parse::PyObjectWrapper;
@ -961,10 +964,7 @@ class PartialEvaluator : public Evaluator {
new_nodes_inputs[1] = NewValueNode(new_signature_value);
FuncGraphPtr func_graph = cnode->func_graph();
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
}
ScopePtr scope = out_conf->node()->scope();
ScopeGuard scope_guard(scope);
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
@ -981,8 +981,8 @@ struct PrimitiveImplInferValue {
};
using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
PrimitiveToImplMap UniformPrimitiveToImplMapValue = {
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
static PrimitiveToImplMap uniform_prim_implement_map = {
{prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
{prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
{prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
@ -1001,19 +1001,21 @@ PrimitiveToImplMap UniformPrimitiveToImplMapValue = {
{prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
{prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
};
};
return uniform_prim_implement_map;
}
PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
std::mutex PrimEvaluatorConstructorMutex;
void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_map) {
void InitPrimEvaluatorConstructors() {
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
for (const auto &iter : prim_eval_impl_map) {
for (const auto &iter : GetPrimitiveToEvalImplMap()) {
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
}
for (const auto &iter : UniformPrimitiveToImplMapValue) {
for (const auto &iter : GetUniformPrimitiveToImplMap()) {
constructor[iter.first] =
InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
}
@ -1028,20 +1030,20 @@ void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_ma
void ClearPrimEvaluatorMap() {
PrimEvaluatorConstructors.clear();
PrimitiveToInferImplMap.clear();
UniformPrimitiveToImplMapValue.clear();
GetPrimitiveToEvalImplMap().clear();
GetUniformPrimitiveToImplMap().clear();
}
bool IsInWhiteList(const PrimitivePtr primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = PrimitiveToInferImplMap.find(primitive);
if (iter != PrimitiveToInferImplMap.end()) {
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter != GetPrimitiveToEvalImplMap().end()) {
return iter->second.in_white_list_;
}
auto uni_iter = UniformPrimitiveToImplMapValue.find(primitive);
if (uni_iter != UniformPrimitiveToImplMapValue.end()) {
auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
return uni_iter->second.in_white_list_;
}
@ -1050,8 +1052,8 @@ bool IsInWhiteList(const PrimitivePtr primitive) {
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = PrimitiveToInferImplMap.find(primitive);
if (iter == PrimitiveToInferImplMap.end()) {
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == GetPrimitiveToEvalImplMap().end()) {
return nullptr;
}
return iter->second.impl_;
@ -1064,7 +1066,7 @@ PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
}
std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
if (constructor.empty()) {
InitPrimEvaluatorConstructors(PrimitiveToInferImplMap);
InitPrimEvaluatorConstructors();
}
return constructor;

View File

@ -296,40 +296,37 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) {
evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
} else {
return std::make_shared<PythonPrimEvaluator>(prim_py);
}
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
}
} else if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
if (engine == nullptr) {
(void)GetPrimEvaluatorConstructors();
}
// If a primitive may have attr, try to create a new evaluator.
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl != nullptr) {
std::shared_ptr<StandardPrimEvaluator> standard_evaluator =
std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
evaluator = standard_evaluator;
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}
}
if (evaluator == nullptr) {
if (engine == nullptr) {
// If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
auto iter = prim_evaluator_map.find(prim);
if (iter == prim_evaluator_map.end()) {
evaluator = nullptr;
} else {
if (iter != prim_evaluator_map.end()) {
evaluator = iter->second;
}
} else {
// If engine is given, get constructor from engine resource.
const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors();
auto iter = prim_evaluator_map.find(prim);
if (iter == prim_evaluator_map.end()) {
evaluator = nullptr;
} else {
if (iter != prim_evaluator_map.end()) {
evaluator = iter->second;
}
}
}
if (evaluator == nullptr) {
MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ").";
}