forked from mindspore-Ecosystem/mindspore
commit
fe4e15de38
|
@ -649,14 +649,6 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
|
|||
}
|
||||
(*ge_options)["graphType"] = "1";
|
||||
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
|
||||
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
|
||||
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
bool training = IsGeTrain();
|
||||
if (training) {
|
||||
(*ge_options)["ge.graphRunMode"] = "1";
|
||||
|
@ -792,6 +784,18 @@ bool GeDeviceContext::FinalizeGe(const std::shared_ptr<MsContext> &inst_context)
|
|||
return true;
|
||||
}
|
||||
|
||||
void GeDeviceResManager::GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
|
||||
transform::SessionOptions *options) {
|
||||
MS_EXCEPTION_IF_NULL(options);
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
|
||||
(*options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
|
||||
(*options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
|
||||
std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
|
||||
if (sess == nullptr) {
|
||||
|
@ -806,6 +810,10 @@ void GeDeviceResManager::CreateSessionAndGraphRunner(bool is_training) {
|
|||
}
|
||||
|
||||
options["ge.enablePrintOpPass"] = "0";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
GeSetContextOptions(ms_context, &options);
|
||||
|
||||
sess = transform::NewSession(options);
|
||||
transform::SetGeSession(sess);
|
||||
}
|
||||
|
|
|
@ -56,6 +56,7 @@ class GeDeviceResManager : public DeviceResManager {
|
|||
void FreeMemory(void *ptr) const override;
|
||||
|
||||
private:
|
||||
static void GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr, transform::SessionOptions *options);
|
||||
std::shared_ptr<MemoryManager> mem_manager_;
|
||||
};
|
||||
|
||||
|
|
|
@ -202,7 +202,7 @@ class BatchNormInfer : public abstract::OpInferBase {
|
|||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterThan, 0, prim_name);
|
||||
const std::set valid_types = {kFloat16, kFloat32};
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
auto scale_type = input_args[kInputIndex1]->BuildType();
|
||||
|
@ -212,7 +212,7 @@ class BatchNormInfer : public abstract::OpInferBase {
|
|||
(void)types.emplace("mean", input_args[kInputIndex3]->BuildType());
|
||||
(void)types.emplace("variance", input_args[kInputIndex4]->BuildType());
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, scale_type, scale_type, scale_type, scale_type});
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue