Refine unify runtime context

This commit is contained in:
lizhenyu 2021-07-10 17:52:42 +08:00
parent 6ab9904f33
commit c50606ef26
4 changed files with 20 additions and 0 deletions

View File

@ -398,6 +398,20 @@ void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
if (prim != nullptr) {
prim->AddAttr(attr_name, converted_ret);
}
if (attr_name == "primitive_target") {
MS_EXCEPTION_IF_NULL(converted_ret);
if (!converted_ret->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
auto target = GetValue<std::string>(converted_ret);
if (target != kCPUDevice && target != kGPUDevice) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
}
}
}
void PrimitivePyAdapter::DelPyAttr(const py::str &name) {

View File

@ -566,6 +566,10 @@ BackendPtr CreateBackend() {
void SetMindRTEnable() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT)) {
return;
}
std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if ((target != kGPUDevice) && (target != kCPUDevice)) {
return;

View File

@ -87,6 +87,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_SAVE_COMPILE_CACHE, false);
set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
backend_policy_ = policy_map_[policy];
}

View File

@ -89,6 +89,7 @@ enum MsCtxParam : unsigned {
MS_CTX_SAVE_COMPILE_CACHE,
MS_CTX_LOAD_COMPILE_CACHE,
MS_CTX_ENABLE_MINDRT,
MS_CTX_ALREADY_SET_ENABLE_MINDRT,
MS_CTX_TYPE_BOOL_END,
// parameter of type int