!49330 modify GE context

Merge pull request !49330 from 没有窗户的小巷/master
This commit is contained in:
i-robot 2023-02-24 06:51:08 +00:00 committed by Gitee
commit fe4e15de38
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 19 additions and 10 deletions

View File

@ -649,14 +649,6 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
}
(*ge_options)["graphType"] = "1";
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
}
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
}
bool training = IsGeTrain();
if (training) {
(*ge_options)["ge.graphRunMode"] = "1";
@ -792,6 +784,18 @@ bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context)
return true;
}
void GeDeviceResManager::GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
transform::SessionOptions *options) {
MS_EXCEPTION_IF_NULL(options);
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
(*options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
}
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
(*options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
}
}
void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
if (sess == nullptr) {
@ -806,6 +810,10 @@ void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
}
options["ge.enablePrintOpPass"] = "0";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
GeSetContextOptions(ms_context, &options);
sess = transform::NewSession(options);
transform::SetGeSession(sess);
}

View File

@ -56,6 +56,7 @@ class GeDeviceResManager : public DeviceResManager {
void FreeMemory(void *ptr) const override;
private:
static void GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr, transform::SessionOptions *options);
std::shared_ptr<MemoryManager> mem_manager_;
};

View File

@ -202,7 +202,7 @@ class BatchNormInfer : public abstract::OpInferBase {
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
const std::set valid_types = {kFloat16, kFloat32};
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
std::map<std::string, TypePtr> types;
auto scale_type = input_args[kInputIndex1]->BuildType();
@ -212,7 +212,7 @@ class BatchNormInfer : public abstract::OpInferBase {
(void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
(void)types.emplace("variance", input_args[kInputIndex4]->BuildType());
}
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, scale_type, scale_type, scale_type, scale_type});
}
};