forked from mindspore-Ecosystem/mindspore
commit
fe4e15de38
|
@ -649,14 +649,6 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
|
||||||
}
|
}
|
||||||
(*ge_options)["graphType"] = "1";
|
(*ge_options)["graphType"] = "1";
|
||||||
|
|
||||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
|
|
||||||
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
|
|
||||||
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool training = IsGeTrain();
|
bool training = IsGeTrain();
|
||||||
if (training) {
|
if (training) {
|
||||||
(*ge_options)["ge.graphRunMode"] = "1";
|
(*ge_options)["ge.graphRunMode"] = "1";
|
||||||
|
@ -792,6 +784,18 @@ bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context)
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GeDeviceResManager::GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
|
||||||
|
transform::SessionOptions *options) {
|
||||||
|
MS_EXCEPTION_IF_NULL(options);
|
||||||
|
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
|
||||||
|
(*options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
|
||||||
|
(*options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
|
void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
|
||||||
std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
|
std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
|
||||||
if (sess == nullptr) {
|
if (sess == nullptr) {
|
||||||
|
@ -806,6 +810,10 @@ void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
|
||||||
}
|
}
|
||||||
|
|
||||||
options["ge.enablePrintOpPass"] = "0";
|
options["ge.enablePrintOpPass"] = "0";
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
GeSetContextOptions(ms_context, &options);
|
||||||
|
|
||||||
sess = transform::NewSession(options);
|
sess = transform::NewSession(options);
|
||||||
transform::SetGeSession(sess);
|
transform::SetGeSession(sess);
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,6 +56,7 @@ class GeDeviceResManager : public DeviceResManager {
|
||||||
void FreeMemory(void *ptr) const override;
|
void FreeMemory(void *ptr) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
static void GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr, transform::SessionOptions *options);
|
||||||
std::shared_ptr<MemoryManager> mem_manager_;
|
std::shared_ptr<MemoryManager> mem_manager_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -202,7 +202,7 @@ class BatchNormInfer : public abstract::OpInferBase {
|
||||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
||||||
const std::set valid_types = {kFloat16, kFloat32};
|
const std::set valid_types = {kFloat16, kFloat32};
|
||||||
auto x_type = input_args[0]->BuildType();
|
auto x_type = input_args[0]->BuildType();
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
|
||||||
|
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
auto scale_type = input_args[kInputIndex1]->BuildType();
|
auto scale_type = input_args[kInputIndex1]->BuildType();
|
||||||
|
@ -212,7 +212,7 @@ class BatchNormInfer : public abstract::OpInferBase {
|
||||||
(void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
|
(void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
|
||||||
(void)types.emplace("variance", input_args[kInputIndex4]->BuildType());
|
(void)types.emplace("variance", input_args[kInputIndex4]->BuildType());
|
||||||
}
|
}
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||||
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, scale_type, scale_type, scale_type, scale_type});
|
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, scale_type, scale_type, scale_type, scale_type});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue