!49330 modify GE context

Merge pull request !49330 from 没有窗户的小巷/master
This commit is contained in:
i-robot 2023-02-24 06:51:08 +00:00 committed by Gitee
commit fe4e15de38
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 19 additions and 10 deletions

View File

@ -649,14 +649,6 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
} }
(*ge_options)["graphType"] = "1"; (*ge_options)["graphType"] = "1";
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
}
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
}
bool training = IsGeTrain(); bool training = IsGeTrain();
if (training) { if (training) {
(*ge_options)["ge.graphRunMode"] = "1"; (*ge_options)["ge.graphRunMode"] = "1";
@ -792,6 +784,18 @@ bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context)
return true; return true;
} }
void GeDeviceResManager::GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
transform::SessionOptions *options) {
MS_EXCEPTION_IF_NULL(options);
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
(*options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
}
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
(*options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
}
}
void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) { void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
std::shared_ptr<::ge::Session> sess = transform::GetGeSession(); std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
if (sess == nullptr) { if (sess == nullptr) {
@ -806,6 +810,10 @@ void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
} }
options["ge.enablePrintOpPass"] = "0"; options["ge.enablePrintOpPass"] = "0";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
GeSetContextOptions(ms_context, &options);
sess = transform::NewSession(options); sess = transform::NewSession(options);
transform::SetGeSession(sess); transform::SetGeSession(sess);
} }

View File

@ -56,6 +56,7 @@ class GeDeviceResManager : public DeviceResManager {
void FreeMemory(void *ptr) const override; void FreeMemory(void *ptr) const override;
private: private:
static void GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr, transform::SessionOptions *options);
std::shared_ptr<MemoryManager> mem_manager_; std::shared_ptr<MemoryManager> mem_manager_;
}; };

View File

@ -202,7 +202,7 @@ class BatchNormInfer : public abstract::OpInferBase {
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name); (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
const std::set valid_types = {kFloat16, kFloat32}; const std::set valid_types = {kFloat16, kFloat32};
auto x_type = input_args[0]->BuildType(); auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name()); (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
auto scale_type = input_args[kInputIndex1]->BuildType(); auto scale_type = input_args[kInputIndex1]->BuildType();
@ -212,7 +212,7 @@ class BatchNormInfer : public abstract::OpInferBase {
(void)types.emplace("mean", input_args[kInputIndex3]->BuildType()); (void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
(void)types.emplace("variance", input_args[kInputIndex4]->BuildType()); (void)types.emplace("variance", input_args[kInputIndex4]->BuildType());
} }
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, scale_type, scale_type, scale_type, scale_type}); return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, scale_type, scale_type, scale_type, scale_type});
} }
}; };