!5352 refactor ms_context implementation

Merge pull request !5352 from fary86/refactor_context_interface
mindspore-ci-bot 2020-08-31 09:27:33 +08:00 committed by Gitee
commit 8d41931456
77 changed files with 582 additions and 554 deletions
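The pattern of the refactor, visible in every hunk below: MsContext's dozens of per-field accessors (enable_task_sink(), save_graphs_flag(), set_device_target(), ...) are replaced by one typed accessor pair, get_param<T>(MsCtxParam) / set_param<T>(MsCtxParam, value), keyed on the new MsCtxParam enum. A minimal sketch of that pattern with simplified stand-in names (an illustration, not the actual MsContext implementation):

#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

enum class CtxParam { kEnableTaskSink, kDeviceTarget };  // stand-in for MsCtxParam

class Context {  // stand-in for MsContext
 public:
  template <typename T> T get_param(CtxParam param) const;
  template <typename T> void set_param(CtxParam param, T value);

 private:
  std::unordered_map<CtxParam, bool> bool_params_;
  std::unordered_map<CtxParam, std::string> string_params_;
};

// One specialization per supported type replaces a pair of hand-written
// accessors per field.
template <> bool Context::get_param<bool>(CtxParam p) const { return bool_params_.at(p); }
template <> void Context::set_param<bool>(CtxParam p, bool v) { bool_params_[p] = v; }
template <> std::string Context::get_param<std::string>(CtxParam p) const { return string_params_.at(p); }
template <> void Context::set_param<std::string>(CtxParam p, std::string v) { string_params_[p] = std::move(v); }

int main() {
  Context ctx;
  ctx.set_param<bool>(CtxParam::kEnableTaskSink, true);           // before: ctx.set_enable_task_sink(true)
  ctx.set_param<std::string>(CtxParam::kDeviceTarget, "Ascend");  // before: ctx.set_device_target("Ascend")
  assert(ctx.get_param<bool>(CtxParam::kEnableTaskSink));
  return 0;
}

Adding a context parameter then means adding an enum key rather than a getter/setter pair, at the cost of call sites spelling out both the key and its type.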


@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->enable_task_sink()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || hccl_data_type_list_.empty()) {


@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->enable_task_sink()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || hccl_data_type_list_.empty()) {


@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->enable_task_sink()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {


@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->enable_task_sink()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true;
}
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {


@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL) &&
+ IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
kernel_type = KernelType::AKG_KERNEL;
}


@ -328,7 +328,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
- bool is_gpu = (context->device_target() == kGPUDevice);
+ bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
<< ", current op num: " << op_info_.size();


@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
- if (context_ptr->execution_mode() == kPynativeMode) {
+ if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else {
@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get());
- if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
+ ConfigManager::GetInstance().iter_num() > 1) {
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!context_ptr->ir_fusion_flag()) {
+ if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
MS_LOG(INFO) << "IRFusion is not enable, skip";
return;
}
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other2_pm = std::make_shared<PassManager>("other2_pm");
other2_pm->AddPass(std::make_shared<GetitemTuple>());
other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
- if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
+ ConfigManager::GetInstance().iter_num() > 1) {
other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
}
other2_pm->AddPass(std::make_shared<CheckConsistency>());
@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!(context_ptr->enable_graph_kernel())) {
+ if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!(context_ptr->enable_graph_kernel())) {
+ if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!(context_ptr->enable_graph_kernel())) {
+ if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &ke
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!context_ptr->ir_fusion_flag()) {
+ if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
MS_LOG(INFO) << "UBFusion is not enable, skip";
return;
}
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
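Note how each optimization pass above repeats the same four-line idiom: read MS_CTX_SAVE_GRAPHS_FLAG, read MS_CTX_SAVE_GRAPHS_PATH, and default the path to ".". A hypothetical helper, not part of this PR, that would collapse the idiom at every call site (assuming the MsContext API introduced by this commit):

#include <memory>
#include <string>
#include <utility>

// Returns {save_graphs flag, save path}, defaulting the path to the current
// working directory exactly as each call site above does.
std::pair<bool, std::string> GetSaveGraphsConfig(const std::shared_ptr<MsContext> &ctx) {
  bool save_graphs = ctx->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  std::string save_graphs_path = ctx->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
  if (save_graphs_path.empty()) {
    save_graphs_path = ".";
  }
  return {save_graphs, save_graphs_path};
}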


@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
+ !ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node;
}


@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto cnode = node->cast<CNodePtr>();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() == kPynativeMode) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return RectifyKernelInfoInPynativeProcess(node);
}
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {


@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}


@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) {
+ if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
+ context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return false;
}
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,


@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}


@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->device_target() == kAscendDevice) {
+ if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
if (!CheckAttrs(strided_slice_grad)) {
MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed";
return nullptr;


@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- auto save_graphs_path = context_ptr->save_graphs_path();
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
- if (context_ptr->save_graphs_flag()) {
+ if (context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir";
DumpIR(file_path, root_graph.get());
}


@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
debugger_->PreExecute(graph);
}
#endif
- if (ms_context->precompile_only()) {
+ if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "Precompile only, stop in build kernel step";
} else {
// alloc memory, including static memory and dynamic memory
@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
child_graph->SetExecOrderByDefault();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() == kGraphMode) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There has " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!";
@ -481,8 +481,8 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs)
#ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (!save_graphs) {
return;
}
- auto save_graphs_path = context_ptr->save_graphs_path();
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() {
if (graph_order.size() > 1) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!context_ptr->enable_task_sink()) {
+ if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
}
}
@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs) {
if (save_graphs_path.empty()) {
save_graphs_path = ".";
@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() == kGraphMode) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There are " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!";
@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs) {
if (save_graphs_path.empty()) {
save_graphs_path = ".";


@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
- if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) {
+ if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
bool need_sync = false;
- if (ms_context->enable_pynative_infer()) {
+ if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true;
}
@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
// Prepare ms context info for dump .pb graph
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool save_graphs = context_ptr->save_graphs_flag();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Optimize
Optimize(graph);
// Select kernel build info
@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
// Summary
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->enable_gpu_summary()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) {
Summary(kernel_graph.get());
}
#ifdef ENABLE_DEBUGGER


@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() {
return;
}
Initialized = true;
- MsContext::GetInstance()->set_execution_mode(kGraphMode);
+ MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
Py_Initialize();
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
if (c_expression == nullptr) {
@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
- ms_context->set_execution_mode(kGraphMode);
- ms_context->set_device_id(device_id);
+ ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+ ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
return FAILED;
}
- ms_context->set_device_target(device);
+ ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;


@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
// in pynative mode, data is only copied to host when the user wants to print it
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
+ ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_need_sync(true);
}
- if (ms_context->execution_mode() != kPynativeMode) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
tensor->SetNeedWait(true);
}
tensor->set_dirty(false);
@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
- if (ms_context->enable_pynative_infer()) {
+ if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
}
if (tensor->is_dirty()) {
@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
- if (ms_context->execution_mode() == kPynativeMode ||
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if (backend_anf != nullptr) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (context_ptr->execution_mode() == kPynativeMode) {
+ if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return backend_anf;
}


@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
debugger_ = Debugger::GetInstance();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- debugger_->Init(device_id_, ms_context->device_target());
+ debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif


@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
- if (context->execution_mode() == kPynativeMode) {
+ if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump";
}
return true;


@ -142,7 +142,7 @@ void Debugger::EnableDebugger() {
// switch memory reuse on or off
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- context_ptr->set_enable_mem_reuse(partial_memory_);
+ context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
// print some message about memory reuse to user
if (partial_memory_) {
MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "


@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
MS_LOG(ERROR) << "ms_context is nullptr";
return;
}
- auto save_graphs_path = ms_context->save_graphs_path();
+ auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}


@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// dump_enable_ is true, close mem reuse
- context_ptr->set_enable_mem_reuse(!dump_enable_);
+ context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, !dump_enable_);
trans_flag_ = trans_flag;
dump_mode_ = mode;
dump_path_ = path;
@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() {
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- auto id = context_ptr->device_id();
+ auto id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
char real_path[PATH_MAX] = {0};
if (nullptr == realpath(config_path_str, real_path)) {
MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str;


@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
manager_ptr->AddFuncGraph(func_graph);
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
- if (MsContext::GetInstance()->is_multi_graph_sink()) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
}


@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
- bool check_bprop_flag = context->check_bprop_flag();
+ bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
// Skip checking if check_bprop not set
if (!check_bprop_flag) {
return;


@ -29,7 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
PConstant const_2(node);
PConstant any_const(node);
- if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
+ if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
@ -41,7 +41,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
// Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
- if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+ if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return nullptr;
}
@ -75,7 +75,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
- if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+ if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return nullptr;
}
PatternNode x, y;


@ -181,7 +181,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}
};
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
- if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) {
+ if (is_on_debug_ && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
auto fg_name =
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];


@ -217,8 +217,8 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
void DrawNode(string name, AnfNodePtr node) {
auto context_ptr = MsContext::GetInstance();
- bool save_graphs = context_ptr->save_graphs_flag();
- auto save_graphs_path = context_ptr->save_graphs_path();
+ bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+ auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}


@ -44,7 +44,7 @@ std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::str
}
void DumpGraph(const FuncGraphPtr &root, const std::string &name) {
- if (MsContext::GetInstance()->save_graphs_flag()) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw(name + ".dot", root);
DumpIR(name + ".ir", root);
ExportIR(name + ".dat", "0", root);


@ -69,7 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
- if (MsContext::GetInstance()->save_graphs_flag()) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
}
MS_LOG(INFO) << "Now entering step auto parallel";


@ -271,7 +271,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
if (!result) {
MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
}
- if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) {
auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -295,20 +295,20 @@ bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res,
static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance();
- if (ms_ctx->execution_mode() != kGraphMode) {
+ if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
return false;
}
- std::string device_target = ms_ctx->device_target();
+ std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target != kAscendDevice) {
return false;
}
- if (!ms_ctx->enable_task_sink()) {
+ if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return false;
}
- if (!ms_ctx->is_multi_graph_sink()) {
+ if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
return false;
}
return true;
@ -325,13 +325,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false);
- context_ptr->set_is_multi_graph_sink(false);
- context_ptr->set_loop_sink_flag(false);
- } else if (context_ptr->execution_mode() != kPynativeMode) {
- std::string device_target = context_ptr->device_target();
+ context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
+ context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
+ } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+ std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == kAscendDevice && backend != kMsVm) {
bc_ptr->set_is_multi_graph_sink(true);
- context_ptr->set_is_multi_graph_sink(true);
+ context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
}
}


@ -49,7 +49,7 @@ inline std::string GetFilePathName(const std::string &file_name) {
if (ms_context == nullptr) {
MS_LOG(EXCEPTION) << "ms_context is nullptr";
}
- auto save_graphs_path = ms_context->save_graphs_path();
+ auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}


@ -48,6 +48,56 @@ using OpLib = mindspore::kernel::OpLib;
using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
+ using mindspore::MsCtxParam;
+ namespace mindspore {
+ void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
+ MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '"
+ << py::str(value.get_type()) << "'.";
+ if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
+ ctx->set_param<bool>(param, value.cast<bool>());
+ return;
+ }
+ if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
+ ctx->set_param<int>(param, value.cast<int>());
+ return;
+ }
+ if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
+ ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
+ return;
+ }
+ if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
+ ctx->set_param<float>(param, value.cast<float>());
+ return;
+ }
+ if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
+ ctx->set_param<std::string>(param, value.cast<std::string>());
+ return;
+ }
+ MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type());
+ }
+ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
+ if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
+ return py::bool_(ctx->get_param<bool>(param));
+ }
+ if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
+ return py::int_(ctx->get_param<int>(param));
+ }
+ if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
+ return py::int_(ctx->get_param<uint32_t>(param));
+ }
+ if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
+ return py::float_(ctx->get_param<float>(param));
+ }
+ if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
+ return py::str(ctx->get_param<std::string>(param));
+ }
+ MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
+ }
+ } // namespace mindspore
// Interface with python
PYBIND11_MODULE(_c_expression, m) {
@ -101,53 +151,48 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.");
(void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.");
+ (void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
+ .value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG)
+ .value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
+ .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
+ .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
+ .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
+ .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
+ .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
+ .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
+ .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
+ .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
+ .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
+ .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
+ .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
+ .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
+ .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
+ .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
+ .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
+ .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
+ .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
+ .value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
+ .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
+ .value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
+ .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
+ .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
+ .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
+ .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
+ .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
+ .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
+ .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
+ .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
+ .value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
+ .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
+ .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.")
.def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.")
.def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.")
.def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.")
.def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.")
.def("get_device_target", &mindspore::MsContext::device_target, "Get device target.")
.def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
.def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
.def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.")
.def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
"Set whether to enable auto mixed precision.")
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
"Get whether to enable reduce precision.")
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
"Set whether to enable reduce precision.")
.def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
.def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
.def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.")
.def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.")
.def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.")
.def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.")
.def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.")
.def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size,
"set variable memory max size")
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.")
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
.def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
.def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")


@ -271,7 +271,7 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!(context_ptr->enable_graph_kernel())) {
+ if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
}


@ -88,7 +88,7 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) {
if (ms_context == nullptr) {
MS_LOG(EXCEPTION) << "ms_context is nullptr";
}
- auto save_graphs_path = ms_context->save_graphs_path();
+ auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
@ -646,7 +646,7 @@ void Pipeline::Run() {
if (!result) {
MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
}
- if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) {
auto graph = resource_->func_graph();
if (graph != nullptr) {
user_graph = graph;
@ -688,7 +688,7 @@ void Pipeline::Run() {
MsProfile::Reset();
#endif
- if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) {
std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
draw::DrawUserFuncGraph(user_graph_file, user_graph);
@ -710,7 +710,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if (!succ) {
MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
}
- if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa<tensor::Tensor>()) {
+ if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == 0 && !converted->isa<tensor::Tensor>()) {
MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString()
<< " is not tensor.";
}
@ -891,7 +891,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
// Convert CNodeList to LinConvertResult.
ConfigManager::GetInstance().set_iter_num(1);
auto runner = convert_fn({app_init}, "");
- if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
+ if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
backend->Link(runner.graph_id);
}
ConfigManager::GetInstance().set_iter_num(size);
@ -965,10 +965,11 @@ void InitHccl() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
(void)context::OpenTsd(ms_context);
- uint32_t device_id = ms_context->device_id();
- std::string device_name = ms_context->device_target();
- ms_context->set_enable_hccl(true);
- if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) {
+ uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+ std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+ ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
+ if (ms_context->backend_policy() == "ms" &&
+ ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
if (!runtime_instance->Init()) {


@ -214,7 +214,7 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
return false;
}
- if (MsContext::GetInstance()->save_graphs_flag()) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug
convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug
convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug
@ -244,7 +244,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
}
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
- if (MsContext::GetInstance()->save_graphs_flag()) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug
DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true);
}


@ -118,8 +118,9 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr ret_base = nullptr;
engine->IncreaseFunctionCallDepth();
- if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) {
- MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << ".";
+ if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
+ MS_LOG(EXCEPTION) << "Exceed function call depth limit "
+ << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << ".";
}
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
@ -409,7 +410,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
bparams.push_back(SensitivityTransform(orig_func_));
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
- bool enable_sparse = context->enable_sparse();
+ bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
if (enable_sparse && arg_spec->isa<AbstractTensor>()) {


@ -62,7 +62,7 @@ class Evaluator : public Base {
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
- bool enable_sparse = context->enable_sparse();
+ bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (!enable_sparse) {
return nullptr;
}


@ -290,7 +290,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if (abs_base->isa<AbstractTensor>()) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
dic["shape"] = arg_tensor->shape()->shape();
- if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
+ if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape();
const auto &max_shape = arg_tensor->shape()->max_shape();
if (!min_shape.empty() && !max_shape.empty()) {


@ -558,8 +558,8 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
auto ms_context = MsContext::GetInstance();
- ms_context->set_enable_pynative_infer(true);
- std::string device_target = ms_context->device_target();
+ ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
+ std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target != kAscendDevice && device_target != kGPUDevice) {
MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
}
@ -567,7 +567,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if (session == nullptr) {
session = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(session);
- session->Init(ms_context->device_id());
+ session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
}
std::vector<tensor::TensorPtr> input_tensors;
@ -578,7 +578,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, &input_tensors);
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
- ms_context->set_enable_pynative_infer(false);
+ ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
return result;
@ -1308,7 +1308,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
// Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance();
if (ms_context != nullptr) {
- ms_context->set_enable_pynative_infer(false);
+ ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
}
ConfigManager::GetInstance().ResetIterNum();
return;


@ -89,7 +89,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size()
<< " is not equal to the args number: " << py_args.size() - 2 << ".";
}
- if (MsContext::GetInstance()->check_bprop_flag()) {
+ if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
for (size_t i = 0; i < grads.size(); i++) {
if (py::isinstance<tensor::Tensor>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i])) {


@ -154,7 +154,7 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- auto device_id = ms_context->device_id();
+ auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type);
@ -261,11 +261,12 @@ void AscendDeviceAddress::SyncStream() const {
MS_LOG(INFO) << "Start!";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() != kPynativeMode && !ms_context->enable_pynative_infer()) {
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
+ !ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
MS_LOG(INFO) << "Finish!";
return;
}
- auto device_id = ms_context->device_id();
+ auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
auto ret = runtime_instance->SyncStream();
@ -348,7 +349,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- auto device_id = ms_context->device_id();
+ auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
auto ret =
@ -475,7 +476,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (ms_context->execution_mode() != kGraphMode && ms_context->execution_mode() != kPynativeMode &&
+ if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode &&
+ ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
type_id_name_map.find(type_id_) != type_id_name_map.end()) {
std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
if (use_trans_data.find(type_format) != use_trans_data.end()) {


@ -158,7 +158,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
bool AscendKernelRuntime::NeedDestroyHccl() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- if (!context_ptr->enable_hccl()) {
+ if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
MS_LOG(INFO) << "Hccl is not enabled";
return false;
}
@ -177,7 +177,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- auto ret = rtSetDevice(context_ptr->device_id());
+ auto ret = rtSetDevice(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
}
@ -461,12 +461,12 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool is_task_sink = context_ptr->enable_task_sink();
+ bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!is_task_sink) {
return true;
}
#ifdef MEM_REUSE_DEBUG
- if (!context_ptr->enable_mem_reuse()) {
+ if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE)) {
// Get normal graph ir for memreuse
mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
}
@ -518,7 +518,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
- bool is_task_sink = context_ptr->enable_task_sink();
+ bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!is_task_sink) {
return true;
}
@ -658,7 +658,7 @@ bool AscendKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "Get MsContext instance failed";
return false;
}
- if (context_ptr->enable_hccl()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
if (!HcclInit()) {
MS_LOG(ERROR) << "HcclInit init failed";
return false;
@ -746,7 +746,7 @@ bool AscendKernelRuntime::DestroyHccl() {
return false;
}
MS_LOG(INFO) << "Hccl destroy successful, status = " << res << ".";
- context_ptr->set_enable_hccl(false);
+ context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
return true;
}


@ -43,7 +43,7 @@ void AscendMemoryManager::MallocDeviceMemory() {
uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
- auto variable_memory_max_size = context->variable_memory_max_size();
+ auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
if (variable_memory_max_size == "0") {
return 0;
}


@ -1373,7 +1373,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
bool AscendStreamAssign::IsTaskSink() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
- if (!ms_context->enable_task_sink()) {
+ if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(INFO) << "Task sink mode is not enable";
return false;
} else {


@ -117,7 +117,7 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
if (!dump_path.has_value()) {
MS_LOG(EXCEPTION) << "Dump path invalid";
}
- auto device_id = context_ptr->device_id();
+ auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/");
MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value();


@ -363,7 +363,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
*precision_reduce = false;
return;
}
- if (context_ptr->enable_reduce_precision()) {
+ if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
kernel_support_datatype, &kernel_match_datatype_idx_copy);
}

View File

@ -117,7 +117,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
const string prof_options_str = context->profiling_options();
const string prof_options_str = context->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
std::vector<string> opts = Split(prof_options_str, ':');
if (opts.empty()) {
MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!";

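StartupProfiling treats the profiling options string as a ':'-separated list. The Split helper is not shown in this diff; a plausible self-contained implementation with the same signature, for reference only (an empty options string yields an empty vector, which triggers the warning above):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Assumed implementation of the Split used above: break a string on a
// single-character delimiter.
std::vector<std::string> Split(const std::string &s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) {
    parts.push_back(item);
  }
  return parts;
}

int main() {
  for (const auto &opt : Split("training_trace:task_trace", ':')) {
    std::cout << opt << '\n';   // training_trace, then task_trace
  }
  return 0;
}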
View File

@ -41,7 +41,7 @@ class ProfilingManager {
inline bool IsProfiling() const {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
return context->enable_profiling();
return context->get_param<bool>(MS_CTX_ENABLE_PROFILING);
}
protected:

View File

@ -342,12 +342,12 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second);
TaskDescReporter task_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.task_desc_info", ret->second);
task_reporter.set_task_ids(task_ids);
task_reporter.set_stream_ids(stream_ids);
task_reporter.ReportData();
GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
GraphDescReporter graph_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.graph_desc_info", ret->second);
graph_profiling_cnode_.erase(ret);
graph_reporter.ReportData();
@ -357,7 +357,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
MS_LOG(ERROR) << "Graph id not found in graph_point";
return;
}
PointReporter point_reporter(context->device_id(), "vm.point");
PointReporter point_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.point");
for (const auto &point : point_iter->second) {
point_reporter.AddReportData(point);
}

View File

@ -416,7 +416,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
mem_manager_->ResetDynamicMemory();
AssignStaticMemoryInput(graph);
AssignStaticMemoryValueNode(graph);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
if (is_enable_dynamic_mem) {
// Use the dynamic memory pool.
InitKernelRefCount(graph);
@ -435,8 +435,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool ret = true;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
bool is_enable_pynative_infer = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
auto graph_id = graph->graph_id();
auto iter = mem_swap_map_.find(graph_id);

View File

@ -29,7 +29,7 @@ bool GPUMemoryAllocator::Init() {
size_t free_size = CudaDriver::free_mem_size();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
limited_device_memory_ = context_ptr->max_device_memory();
limited_device_memory_ = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024);
if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) {
MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size
@ -44,7 +44,7 @@ bool GPUMemoryAllocator::Init() {
void GPUMemoryAllocator::CheckMaxDeviceMemory() const {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto max_device_memory = context_ptr->max_device_memory();
auto max_device_memory = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
// Currently not support modifying the max device memory.
if (limited_device_memory_ != max_device_memory) {
MS_LOG(EXCEPTION)

View File

@ -37,7 +37,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// If use the dynamic memory pool, then alloc the first memory block to init.
if (context_ptr->enable_dynamic_mem_pool()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
auto device_addr = MallocMemFromMemPool(1);
if (!device_addr) {
MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";
@ -65,7 +65,7 @@ void GPUMemoryManager::FreeDeviceMemory() {
uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_dynamic_mem_pool()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
auto device_ptr = MallocMemFromMemPool(size);
MS_EXCEPTION_IF_NULL(device_ptr);
return AddressOffset(device_ptr, 0);

View File

@ -162,7 +162,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return false;
}
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {

View File

@ -60,8 +60,8 @@ void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &k
bool KernelAdjust::NeedInsertSwitch() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
ConfigManager::GetInstance().iter_num() > 1);
return (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) &&
context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1);
}
CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,

View File

@ -50,7 +50,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#endif
bool is_task_sink = context_ptr->enable_task_sink();
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (is_task_sink) {
ret = RunTask(graph);
} else {
@ -502,7 +502,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
MS_LOG(INFO) << "communication op addr exist";
continue;
}
if (context_ptr->enable_hccl()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
}
total_size += mem_size;
@ -646,7 +646,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
DeviceAddressPtr address = nullptr;
address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, node_size)) {
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
} else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
@ -682,7 +683,8 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
DeviceAddressPtr address = nullptr;
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
MS_EXCEPTION_IF_NULL(address);
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
} else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
@ -701,7 +703,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_manager_);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
bool is_enable_mem_reuse = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE);
auto mem_type = kDynamicMem;
if (is_enable_mem_reuse) {
mem_manager_->MallocReusedDynamicMem(graph);

View File

@ -54,7 +54,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me
uint8_t *ptr = nullptr;
if (AnfAlgo::IsCommunicationOp(node)) {
bool communication_mem = false;
if (context_ptr->enable_hccl()) {
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
communication_mem = true;
}
if (type == kStaticMem) {

View File

@ -1070,7 +1070,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
convertor.inputs_ = inputs;
(void)convertor.ConvertAllNode().BuildGraph();
std::string name = graph_node->ToString() + "_ge_graph.dot";
if (MsContext::GetInstance()->save_graphs_flag()) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
convertor.DrawComputeGraph(name);
}
branches_map_[node.get()] = *(convertor.df_graph_);

View File

@ -41,13 +41,13 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
if (ms_context_ptr->is_pynative_ge_init()) {
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
return true;
}
if (ms_context_ptr->tsd_ref()) {
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
ms_context_ptr->set_tsd_ref("++");
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
return true;
}
@ -59,7 +59,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
unsigned int device_id;
unsigned int rank_size = 1;
device_id = ms_context_ptr->device_id();
device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto rank_size_env = common::GetEnv("RANK_SIZE");
if (rank_size_env.empty()) {
@ -79,7 +79,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
return false;
}
ms_context_ptr->set_tsd_ref("++");
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE
int32_t initStatus = tdt::TdtHostInit(device_id);
if (initStatus != TDT_OK_CODE) {
@ -88,7 +88,8 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
}
ms_context_ptr->tdt_print_ = std::thread(TensorPrint());
#endif
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << ms_context_ptr->tsd_ref() << ".";
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
return true;
}
@ -96,12 +97,12 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
if (ms_context_ptr->tsd_ref() == 0) {
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true;
}
ms_context_ptr->set_tsd_ref("--");
if (force || ms_context_ptr->tsd_ref() == 0) {
ms_context_ptr->set_tsd_ref(" ");
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE
int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
if (stopStatus != TDT_OK_CODE) {
@ -123,17 +124,17 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
}
#endif
auto device_id = ms_context_ptr->device_id();
auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
TDT_StatusT status = TsdClose(device_id);
if (status != TDT_OK) {
MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << ".";
return false;
}
ms_context_ptr->set_pynative_ge_init(false);
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << ".";
} else {
MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->tsd_ref()
<< ".";
MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
}
return true;
@ -159,14 +160,14 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
#ifdef ENABLE_GE
(*ge_options)["device_id"] = "0";
(*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->enable_dump());
(*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->save_dump_path();
(*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
(*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
(*ge_options)["ge.exec.dumpMode"] = "output";
MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->enable_dump())
<< " and save dump path is " << ms_context_ptr->save_dump_path() << ".";
(*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->enable_profiling());
if (ms_context_ptr->enable_profiling()) {
(*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->profiling_options();
MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
<< " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
(*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING));
if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)) {
(*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
}
(*ge_options)["rank_table_file"] = "";
@ -178,12 +179,12 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
(*ge_options)["graphType"] = "1";
if (ms_context_ptr->graph_memory_max_size() != "0") {
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->graph_memory_max_size();
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
}
if (ms_context_ptr->variable_memory_max_size() != "0") {
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->variable_memory_max_size();
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
}
#if ENABLE_TRAIN == 1
@ -224,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
// Enable auto mixed precision according to the context options
if (ms_context_ptr->auto_mixed_precision_flag()) {
if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) {
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
} else {
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
@ -240,7 +241,7 @@ void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<s
}
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto env_rank_id = common::GetEnv("RANK_ID");
auto env_device_id = std::to_string(ms_context_ptr->device_id());
auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
if (!(env_table_file.empty() || env_rank_id.empty())) {
MS_LOG(INFO) << "Initialize Ge for distribute parameter";
MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
@ -275,12 +276,12 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
#ifdef ENABLE_GE
if (ms_context_ptr->is_pynative_ge_init()) {
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
return true;
}
if (ms_context_ptr->ge_ref()) {
ms_context_ptr->set_ge_ref("++");
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
return true;
}
@ -293,8 +294,8 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "Initialize GE failed!";
}
}
ms_context_ptr->set_ge_ref("++");
MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->ge_ref() << ".";
ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
#endif
return true;
}
@ -303,12 +304,13 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
if (ms_context_ptr->is_pynative_ge_init() || ms_context_ptr->ge_ref() || ms_context_ptr->tsd_ref()) {
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
return true;
}
(void)OpenTsd(ms_context_ptr);
(void)InitGe(ms_context_ptr);
ms_context_ptr->set_pynative_ge_init(true);
ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
return true;
}
@ -317,12 +319,12 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
MS_LOG(EXCEPTION) << "nullptr";
}
#ifdef ENABLE_GE
if (ms_context_ptr->ge_ref() == 0) {
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
return true;
}
ms_context_ptr->set_ge_ref("--");
if (force || ms_context_ptr->ge_ref() == 0) {
ms_context_ptr->set_ge_ref(" ");
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
try {
DfGraphManager::GetInstance().DeleteGraphRunner();
DfGraphManager::GetInstance().DeleteGeSession();
@ -337,7 +339,8 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
}
ms_context_ptr->set_pynative_ge_init(false);
} else {
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->ge_ref() << ".";
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
}
#endif
return true;
@ -347,14 +350,14 @@ bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
return ms_context_ptr->IsTsdOpened();
return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
}
bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
}
return ms_context_ptr->IsGeInited();
return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
}
// Register for device type.

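OpenTsd/CloseTsd (and InitGe/FinalizeGe below them) manage lifetimes with a uint32 reference count held in the context; increase_param/decrease_param replace the old string-coded set_tsd_ref("++"/"--") protocol. A self-contained sketch of that open/close discipline, with hypothetical names:

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the context's MS_CTX_TSD_REF counter.
struct RefCtx {
  uint32_t tsd_ref = 0;
};

bool Open(RefCtx *ctx) {
  if (ctx->tsd_ref > 0) {  // already open: just take another reference
    ++ctx->tsd_ref;
    return true;
  }
  // ... the real code opens the TSD client here ...
  ++ctx->tsd_ref;
  return true;
}

bool Close(RefCtx *ctx, bool force) {
  if (ctx->tsd_ref == 0) {
    return true;                 // nothing to close
  }
  --ctx->tsd_ref;
  if (force || ctx->tsd_ref == 0) {
    ctx->tsd_ref = 0;            // force releases every outstanding reference
    // ... the real code tears the TSD client down here ...
  }
  return true;
}

int main() {
  RefCtx ctx;
  Open(&ctx);
  Open(&ctx);
  Close(&ctx, false);
  std::cout << ctx.tsd_ref << '\n';  // 1: one caller still holds the client
  Close(&ctx, false);
  std::cout << ctx.tsd_ref << '\n';  // 0: the last close ran the teardown
  return 0;
}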
View File

@ -353,7 +353,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse) {
return std::make_shared<abstract::AbstractUndetermined>();
}

View File

@ -273,7 +273,7 @@ void TensorPrint::operator()() {
prntpb::Print print;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string print_file_path = ms_context->print_file_path();
std::string print_file_path = ms_context->get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
if (print_file_path == "") {
while (true) {
std::vector<tdt::DataItem> bundle;

View File

@ -59,7 +59,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
graph_id = target_sess_->CompileGraphAsync(lst, outputs);
}
if (MsContext::GetInstance()->precompile_only()) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result;
}
@ -180,7 +180,7 @@ void MsBackend::CreateOtherSession(const std::string &target) {
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
uint32_t device_id = context_ptr->device_id();
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
other_sess_->Init(device_id);
other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
other_device_ = target;

View File

@ -56,7 +56,7 @@ namespace {
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string last_target = context_ptr->device_target();
std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
for (auto &node : nodes) {
if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node);
@ -348,7 +348,7 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
if (prim->name() == prim::kPrimBpropCut->name()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_enable_pynative_hook(true);
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
}
if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
@ -412,7 +412,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
nodes = SplitSort(graph, default_target);
return SplitNodesWithTarget(nodes, graph);
}
@ -920,17 +920,17 @@ BackendPtr CreateBackend() {
}
if (name == kMsConvert) {
std::string target = context_ptr->device_target();
uint32_t device_id = context_ptr->device_id();
std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto backend = std::make_shared<MsBackend>(name, target, device_id);
std::string device_target = MsContext::GetInstance()->device_target();
std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == kAscendDevice) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
backend->set_is_multi_graph_sink(false);
context_ptr->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
} else {
backend->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
}
}
return backend;

View File

@ -22,7 +22,7 @@ import threading
from collections import namedtuple
from types import FunctionType
from mindspore import log as logger
from mindspore._c_expression import MSContext
from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
@ -157,9 +157,15 @@ class _Context:
raise ValueError("Context handle is none in context!!!")
return value
def get_param(self, param):
return ms_ctx_get_param(self._context_handle, param)
def set_param(self, param, value):
ms_ctx_set_param(self._context_handle, param, value)
@property
def mode(self):
return self._context_handle.get_execution_mode()
return self.get_param(ms_ctx_param.execution_mode)
@mode.setter
def mode(self, mode):
@ -169,15 +175,17 @@ class _Context:
Args:
mode (int): GRAPH_MODE or PYNATIVE_MODE.
"""
self._context_handle.set_execution_mode(mode)
if mode == PYNATIVE_MODE:
if self.enable_debug_runtime:
self.set_backend_policy("vm")
self._context_switches.push(True, None)
else:
elif mode == GRAPH_MODE:
if self.enable_debug_runtime:
self.set_backend_policy("ge")
self._context_switches.push(False, None)
else:
raise ValueError(f'The execution mode {mode} is invalid!')
self.set_param(ms_ctx_param.execution_mode, mode)
def set_backend_policy(self, policy):
success = self._context_handle.set_backend_policy(policy)
@ -186,110 +194,106 @@ class _Context:
@property
def precompile_only(self):
return self._context_handle.get_precompile_only()
return self.get_param(ms_ctx_param.precompile_only)
@precompile_only.setter
def precompile_only(self, precompile_only):
self._context_handle.set_precompile_only(precompile_only)
self.set_param(ms_ctx_param.precompile_only, precompile_only)
@property
def save_graphs(self):
return self._context_handle.get_save_graphs_flag()
return self.get_param(ms_ctx_param.save_graphs_flag)
@save_graphs.setter
def save_graphs(self, save_graphs_flag):
self._context_handle.set_save_graphs_flag(save_graphs_flag)
self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag)
@property
def save_graphs_path(self):
return self._context_handle.get_save_graphs_path()
return self.get_param(ms_ctx_param.save_graphs_path)
@save_graphs_path.setter
def save_graphs_path(self, save_graphs_path):
self._context_handle.set_save_graphs_path(
_make_directory(save_graphs_path))
self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
@property
def device_target(self):
return self._context_handle.get_device_target()
return self.get_param(ms_ctx_param.device_target)
@device_target.setter
def device_target(self, target):
success = self._context_handle.set_device_target(target)
if not success:
raise ValueError("Target device name is invalid!!!")
if self.enable_debug_runtime and self.device_target == "CPU":
valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
if target not in valid_targets:
raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
if target == "Davinci":
target = "Ascend"
self.set_param(ms_ctx_param.device_target, target)
if self.enable_debug_runtime and target == "CPU":
self.set_backend_policy("vm")
@property
def device_id(self):
return self._context_handle.get_device_id()
return self.get_param(ms_ctx_param.device_id)
@device_id.setter
def device_id(self, device_id):
if device_id < 0 or device_id > 4095:
raise ValueError(
"Device id must be in [0, 4095], but got {}".format(device_id))
success = self._context_handle.set_device_id(device_id)
if not success:
raise RuntimeError("Device id set failed!!!")
raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
self.set_param(ms_ctx_param.device_id, device_id)
@property
def max_call_depth(self):
return self._context_handle.get_max_call_depth()
return self.get_param(ms_ctx_param.max_call_depth)
@max_call_depth.setter
def max_call_depth(self, max_call_depth):
if max_call_depth <= 0:
raise ValueError(
"Max call depth must be greater than 0, but got {}".format(max_call_depth))
self._context_handle.set_max_call_depth(max_call_depth)
raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
@property
def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag()
return self.get_param(ms_ctx_param.auto_mixed_precision_flag)
@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self._context_handle.set_auto_mixed_precision_flag(
enable_auto_mixed_precision)
self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision)
@property
def enable_reduce_precision(self):
return self._context_handle.get_enable_reduce_precision_flag()
return self.get_param(ms_ctx_param.enable_reduce_precision_flag)
@enable_reduce_precision.setter
def enable_reduce_precision(self, enable_reduce_precision):
self._context_handle.set_enable_reduce_precision_flag(
enable_reduce_precision)
self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision)
@property
def enable_dump(self):
return self._context_handle.get_enable_dump()
return self.get_param(ms_ctx_param.enable_dump)
@enable_dump.setter
def enable_dump(self, enable_dump):
self._context_handle.set_enable_dump(enable_dump)
self.set_param(ms_ctx_param.enable_dump, enable_dump)
@property
def save_dump_path(self):
return self._context_handle.get_save_dump_path()
return self.get_param(ms_ctx_param.save_dump_path)
@save_dump_path.setter
def save_dump_path(self, save_dump_path):
self._context_handle.set_save_dump_path(save_dump_path)
self.set_param(ms_ctx_param.save_dump_path, save_dump_path)
@property
def enable_profiling(self):
return self._context_handle.get_enable_profiling()
return self.get_param(ms_ctx_param.enable_profiling)
@enable_profiling.setter
def enable_profiling(self, flag):
self._context_handle.set_enable_profiling(flag)
self.set_param(ms_ctx_param.enable_profiling, flag)
@property
def profiling_options(self):
return self._context_handle.get_profiling_options()
return self.get_param(ms_ctx_param.profiling_options)
@profiling_options.setter
def profiling_options(self, option):
@ -298,15 +302,15 @@ class _Context:
if option not in options:
raise ValueError("Profiling options must be in 'training_trace' 'task_trace' "
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
self._context_handle.set_profiling_options(option)
self.set_param(ms_ctx_param.profiling_options, option)
@property
def enable_graph_kernel(self):
return self._context_handle.get_enable_graph_kernel()
return self.get_param(ms_ctx_param.enable_graph_kernel)
@enable_graph_kernel.setter
def enable_graph_kernel(self, graph_kernel_switch_):
self._context_handle.set_enable_graph_kernel(graph_kernel_switch_)
self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_)
@property
def reserve_class_name_in_scope(self):
@ -325,20 +329,14 @@ class _Context:
@variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size):
if not check_input_format(variable_memory_max_size):
raise ValueError(
"Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError(
"Context param variable_memory_max_size should be less than 31GB.")
variable_memory_max_size_ = variable_memory_max_size[:-
2] + " * 1024 * 1024 * 1024"
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - \
int(variable_memory_max_size[:-2])
graph_memory_max_size_ = str(
graph_memory_max_size) + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(
variable_memory_max_size_)
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
@property
def enable_ge(self):
@ -355,15 +353,15 @@ class _Context:
@property
def check_bprop(self):
return self._context_handle.get_check_bprop_flag()
return self.get_param(ms_ctx_param.check_bprop_flag)
@check_bprop.setter
def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag)
self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag)
@property
def max_device_memory(self):
return self._context_handle.get_max_device_memory()
return self.get_param(ms_ctx_param.max_device_memory)
@max_device_memory.setter
def max_device_memory(self, max_device_memory):
@ -372,7 +370,7 @@ class _Context:
max_device_memory_value = float(max_device_memory[:-2])
if max_device_memory_value == 0:
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
self._context_handle.set_max_device_memory(max_device_memory_value)
self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
@property
def print_file_path(self):
@ -392,15 +390,15 @@ class _Context:
full_file_name = os.path.join(path, file_name)
else:
full_file_name = print_file_path
self._context_handle.set_print_file_path(full_file_name)
self.set_param(ms_ctx_param.print_file_path, full_file_name)
@property
def enable_sparse(self):
return self._context_handle.get_enable_sparse()
return self.get_param(ms_ctx_param.enable_sparse)
@enable_sparse.setter
def enable_sparse(self, enable_sparse):
self._context_handle.set_enable_sparse(enable_sparse)
self.set_param(ms_ctx_param.enable_sparse, enable_sparse)
def check_input_format(x):
import re
@ -486,8 +484,6 @@ def set_auto_parallel_context(**kwargs):
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises:
ValueError: If input key is not attribute in auto parallel context.
@ -501,7 +497,6 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(max_call_depth=80)
"""
_set_auto_parallel_context(**kwargs)
@ -603,6 +598,7 @@ def set_context(**kwargs):
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
suffix to the file.
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises:
ValueError: If input key is not an attribute in context.
@ -623,6 +619,7 @@ def set_context(**kwargs):
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
>>> context.set_context(max_device_memory="3.5GB")
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
"""
for key, value in kwargs.items():
if not hasattr(_context(), key):

View File

@ -51,7 +51,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse && dflt->isa<AbstractTensor>()) {
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());

View File

@ -232,7 +232,7 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
return default_target;
}
@ -248,7 +248,7 @@ std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &pri
std::string GetCNodeTarget(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target();
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (!node->isa<CNode>()) {
return default_target;
}

View File

@ -652,7 +652,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
}
if (MsContext::GetInstance()->is_multi_graph_sink()) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
}

View File

@ -30,7 +30,7 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (!enable_sparse) {
return nullptr;
}

View File

@ -32,49 +32,50 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke
{"vm_prior", kMsBackendVmPrior}};
MsContext::MsContext(const std::string &policy, const std::string &target) {
save_graphs_flag_ = false;
save_graphs_path_ = ".";
enable_dump_ = false;
save_dump_path_ = ".";
tsd_ref_ = 0;
ge_ref_ = 0;
is_multi_graph_sink_ = false;
is_pynative_ge_init_ = false;
enable_reduce_precision_ = true;
set_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG, false);
set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
set_param<uint32_t>(MS_CTX_TSD_REF, 0);
set_param<uint32_t>(MS_CTX_GE_REF, 0);
set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
auto env_device = common::GetEnv("DEVICE_ID");
if (!env_device.empty()) {
device_id_ = UlongToUint(std::stoul(env_device.c_str()));
uint32_t device_id = UlongToUint(std::stoul(env_device.c_str()));
set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
} else {
device_id_ = 0;
set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
}
max_call_depth_ = MAX_CALL_DEPTH_DEFAULT;
backend_policy_ = policy_map_[policy];
device_target_ = target;
execution_mode_ = kPynativeMode;
enable_task_sink_ = true;
ir_fusion_flag_ = true;
enable_hccl_ = false;
set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
set_param<bool>(MS_CTX_ENABLE_HCCL, false);
#ifdef ENABLE_DEBUGGER
enable_mem_reuse_ = false;
set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, false);
#else
enable_mem_reuse_ = true;
set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, true);
#endif
enable_gpu_summary_ = true;
precompile_only_ = false;
auto_mixed_precision_flag_ = false;
enable_pynative_infer_ = false;
enable_pynative_hook_ = false;
enable_dynamic_mem_pool_ = true;
graph_memory_max_size_ = "0";
variable_memory_max_size_ = "0";
enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice;
profiling_mode_ = false;
profiling_options_ = "training_trace";
check_bprop_flag_ = false;
max_device_memory_ = kDefaultMaxDeviceMemory;
print_file_path_ = "";
enable_graph_kernel_ = false;
enable_sparse_ = false;
set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
backend_policy_ = policy_map_[policy];
}
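The rewritten constructor seeds MS_CTX_DEVICE_ID from the DEVICE_ID environment variable before applying the remaining defaults. The lookup reduces to this standalone sketch (function name hypothetical; like the original, it does not guard std::stoul against non-numeric input):

#include <cstdint>
#include <cstdlib>
#include <string>

// Mirror of the constructor's env handling: use DEVICE_ID when present,
// otherwise fall back to device 0.
uint32_t DeviceIdFromEnv() {
  const char *env = std::getenv("DEVICE_ID");
  if (env == nullptr || *env == '\0') {
    return 0;
  }
  return static_cast<uint32_t>(std::stoul(std::string(env)));
}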
std::shared_ptr<MsContext> MsContext::GetInstance() {
@ -106,54 +107,4 @@ std::string MsContext::backend_policy() const {
}
return "unknown";
}
void MsContext::set_execution_mode(int execution_mode) {
if (execution_mode != kGraphMode && execution_mode != kPynativeMode) {
MS_LOG(EXCEPTION) << "The execution mode is invalid!";
}
execution_mode_ = execution_mode;
}
bool MsContext::set_device_target(const std::string &target) {
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(ERROR) << "invalid device target name: " << target;
return false;
}
if (target == kDavinciDevice) {
device_target_ = kAscendDevice;
} else {
device_target_ = target;
}
if (seter_) {
seter_(device_target_);
}
MS_LOG(INFO) << "ms set context device target:" << target;
return true;
}
bool MsContext::set_device_id(uint32_t device_id) {
device_id_ = device_id;
MS_LOG(INFO) << "ms set context device id:" << device_id;
return true;
}
void MsContext::set_tsd_ref(const std::string &op) {
if (op == "--") {
tsd_ref_--;
} else if (op == "++") {
tsd_ref_++;
} else {
tsd_ref_ = 0;
}
}
void MsContext::set_ge_ref(const std::string &op) {
if (op == "--") {
ge_ref_--;
} else if (op == "++") {
ge_ref_++;
} else {
ge_ref_ = 0;
}
}
} // namespace mindspore

View File

@ -49,6 +49,69 @@ const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice,
// The default max available device memory is 1024GB.
const float kDefaultMaxDeviceMemory = 1024;
// enum definition for MindSpore Context Parameter
enum MsCtxParam : unsigned {
// parameter of type bool
MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_CHECK_BPROP_FLAG,
MS_CTX_ENABLE_DUMP,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
MS_CTX_ENABLE_GPU_SUMMARY,
MS_CTX_ENABLE_GRAPH_KERNEL,
MS_CTX_ENABLE_HCCL,
MS_CTX_ENABLE_LOOP_SINK,
MS_CTX_ENABLE_MEM_REUSE,
MS_CTX_ENABLE_PYNATIVE_HOOK,
MS_CTX_ENABLE_PYNATIVE_INFER,
MS_CTX_ENABLE_REDUCE_PRECISION,
MS_CTX_ENABLE_SPARSE,
MS_CTX_ENABLE_TASK_SINK,
MS_CTX_IR_FUSION_FLAG,
MS_CTX_IS_MULTI_GRAPH_SINK,
MS_CTX_IS_PYNATIVE_GE_INIT,
MS_CTX_PRECOMPILE_ONLY,
MS_CTX_ENABLE_PROFILING,
MS_CTX_SAVE_GRAPHS_FLAG,
MS_CTX_TYPE_BOOL_END,
// parameter of type int
MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
MS_CTX_TYPE_INT_END,
// parameter of type uint32
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
MS_CTX_GE_REF,
MS_CTX_MAX_CALL_DEPTH,
MS_CTX_TSD_REF,
MS_CTX_TYPE_UINT32_END,
// parameter of type float
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
MS_CTX_TYPE_FLOAT_END,
// parameter of type string
MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
MS_CTX_GRAPH_MEMORY_MAX_SIZE,
MS_CTX_PRINT_FILE_PATH,
MS_CTX_PROFILING_OPTIONS,
MS_CTX_SAVE_DUMP_PATH,
MS_CTX_SAVE_GRAPHS_PATH,
MS_CTX_VARIABLE_MEMORY_MAX_SIZE,
MS_CTX_TYPE_STRING_END,
// parameter numbers of each type
NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN,
NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN,
NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN,
NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN,
NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN
};
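The BEGIN/END markers make each type's parameters a contiguous slice of one enum: NUM_*_PARAMS sizes a plain array per type, and (param - *_BEGIN) is the array index, so no map lookup is needed. A standalone sketch of the same technique with two hypothetical parameter types:

#include <cassert>
#include <string>

// Two type ranges carved out of one enum, mirroring MsCtxParam above.
enum DemoParam : unsigned {
  D_BOOL_BEGIN,
  D_FLAG_A = D_BOOL_BEGIN,
  D_FLAG_B,
  D_BOOL_END,
  D_STRING_BEGIN = D_BOOL_END,
  D_NAME = D_STRING_BEGIN,
  D_STRING_END,
  D_NUM_BOOLS = D_BOOL_END - D_BOOL_BEGIN,        // 2
  D_NUM_STRINGS = D_STRING_END - D_STRING_BEGIN,  // 1
};

struct DemoContext {
  bool bools[D_NUM_BOOLS] = {false, false};
  std::string strings[D_NUM_STRINGS];
  // Subtracting the range's BEGIN turns the enumerator into an array index.
  bool get_bool(DemoParam p) const { return bools[p - D_BOOL_BEGIN]; }
  void set_bool(DemoParam p, bool v) { bools[p - D_BOOL_BEGIN] = v; }
};

int main() {
  DemoContext ctx;
  ctx.set_bool(D_FLAG_B, true);
  assert(ctx.get_bool(D_FLAG_B) && !ctx.get_bool(D_FLAG_A));
  return 0;
}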
class MsContext {
public:
MsContext(const std::string &backend_policy, const std::string &target);
@ -62,156 +125,113 @@ class MsContext {
std::string backend_policy() const;
bool set_backend_policy(const std::string &policy);
int execution_mode() const { return execution_mode_; }
void set_execution_mode(int execution_mode);
bool enable_pynative_infer() const { return enable_pynative_infer_; }
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
bool enable_pynative_hook() const { return enable_pynative_hook_; }
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }
bool enable_task_sink() const { return enable_task_sink_; }
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
bool precompile_only() const { return precompile_only_; }
std::string device_target() const { return device_target_; }
bool set_device_target(const std::string &target);
uint32_t device_id() const { return device_id_; }
bool set_device_id(uint32_t device_id);
// uint32_t max_call_depth_
uint32_t max_call_depth() const { return max_call_depth_; }
inline bool set_max_call_depth(uint32_t max_call_depth) {
max_call_depth_ = max_call_depth;
return true;
}
bool save_graphs_flag() const { return save_graphs_flag_; }
void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; }
std::string save_graphs_path() const { return save_graphs_path_; }
void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; }
bool IsGeInited() { return ge_ref_ > 0; }
void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; }
bool enable_hccl() const { return enable_hccl_; }
bool ir_fusion_flag() const { return ir_fusion_flag_; }
bool loop_sink_flag() const { return enable_loop_sink_; }
void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; }
void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
bool enable_mem_reuse() const { return enable_mem_reuse_; }
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
bool enable_gpu_summary() const { return enable_gpu_summary_; }
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
}
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; }
void set_enable_dump(bool flag) { enable_dump_ = flag; }
bool enable_dump() const { return enable_dump_; }
void set_save_dump_path(const std::string &path) { save_dump_path_ = path; }
std::string save_dump_path() const { return save_dump_path_; }
bool IsTsdOpened() const { return tsd_ref_ > 0; }
void set_tsd_ref(const std::string &op);
uint32_t tsd_ref() const { return tsd_ref_; }
void set_ge_ref(const std::string &op);
uint32_t ge_ref() const { return ge_ref_; }
bool is_pynative_ge_init() { return is_pynative_ge_init_; }
void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; }
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; }
bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; }
void set_graph_memory_max_size(const std::string &graph_memory_max_size) {
graph_memory_max_size_ = graph_memory_max_size;
}
void set_variable_memory_max_size(const std::string &variable_memory_max_size) {
variable_memory_max_size_ = variable_memory_max_size;
}
const std::string &variable_memory_max_size() const { return variable_memory_max_size_; }
const std::string &graph_memory_max_size() const { return graph_memory_max_size_; }
void set_enable_profiling(bool flag) { profiling_mode_ = flag; }
bool enable_profiling() const { return profiling_mode_; }
void set_profiling_options(const std::string &options) { profiling_options_ = options; }
std::string profiling_options() const { return profiling_options_; }
bool check_bprop_flag() const { return check_bprop_flag_; }
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
void set_print_file_path(const std::string &file) { print_file_path_ = file; }
const std::string &print_file_path() const { return print_file_path_; }
float max_device_memory() const { return max_device_memory_; }
void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; }
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
bool enable_graph_kernel() const { return enable_graph_kernel_; }
bool enable_sparse() const { return enable_sparse_; }
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
static void device_seter(DeviceSeter device) { seter_ = device; }
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
std::thread tdt_print_;
template <typename T>
void set_param(MsCtxParam param, const T &value) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
const T &get_param(MsCtxParam param) const {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
void increase_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
void decrease_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
private:
inline static DeviceSeter seter_ = nullptr;
inline static DeviceTypeSeter device_type_seter_ = nullptr;
static std::shared_ptr<MsContext> inst_context_;
static std::map<std::string, MsBackendPolicy> policy_map_;
bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS];
int int_params_[MsCtxParam::NUM_INT_PARAMS];
uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
MsBackendPolicy backend_policy_;
std::string device_target_;
uint32_t device_id_;
uint32_t max_call_depth_;
int execution_mode_;
bool enable_pynative_infer_;
bool enable_pynative_hook_;
bool save_graphs_flag_;
std::string save_graphs_path_;
uint32_t tsd_ref_;
uint32_t ge_ref_;
bool enable_task_sink_;
bool enable_hccl_;
bool precompile_only_;
bool ir_fusion_flag_;
bool auto_mixed_precision_flag_;
bool enable_reduce_precision_;
bool enable_loop_sink_;
bool enable_mem_reuse_;
bool enable_gpu_summary_;
bool enable_dump_;
std::string save_dump_path_;
bool is_multi_graph_sink_;
bool is_pynative_ge_init_;
bool enable_dynamic_mem_pool_;
std::string graph_memory_max_size_;
std::string variable_memory_max_size_;
bool profiling_mode_;
std::string profiling_options_;
bool check_bprop_flag_;
float max_device_memory_;
std::string print_file_path_;
bool enable_graph_kernel_;
bool enable_sparse_;
};
// set method implementation for type bool/int/uint32_t/float/std::string
template <>
inline void MsContext::set_param<bool>(MsCtxParam param, const bool &value) {
bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value;
}
template <>
inline void MsContext::set_param<int>(MsCtxParam param, const int &value) {
int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value;
}
template <>
inline void MsContext::set_param<uint32_t>(MsCtxParam param, const uint32_t &value) {
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value;
}
template <>
inline void MsContext::set_param<float>(MsCtxParam param, const float &value) {
float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value;
}
template <>
inline void MsContext::set_param<std::string>(MsCtxParam param, const std::string &value) {
if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) {
MS_LOG(INFO) << "ms set context device target:" << value;
seter_(value);
}
string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value;
}
// get method implementation for type bool/int/uint32_t/float/std::string
template <>
inline const bool &MsContext::get_param<bool>(MsCtxParam param) const {
return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN];
}
template <>
inline const int &MsContext::get_param<int>(MsCtxParam param) const {
return int_params_[param - MS_CTX_TYPE_INT_BEGIN];
}
template <>
inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const {
return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN];
}
template <>
inline const float &MsContext::get_param<float>(MsCtxParam param) const {
return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN];
}
template <>
inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const {
return string_params_[param - MS_CTX_TYPE_STRING_BEGIN];
}
// increase method implementation for type uint32_t
template <>
inline void MsContext::increase_param<uint32_t>(MsCtxParam param) {
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++;
}
// decrease method implementation for type uint32_t
template <>
inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) {
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--;
}
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_

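Putting the pieces together, a caller-side sketch against the header above (this assumes a MindSpore source tree; the include path is inferred from the include guard, and the function itself is hypothetical, not part of the commit):

#include "utils/ms_context.h"   // path assumed from MINDSPORE_CORE_UTILS_MS_CONTEXT_H_

// Hypothetical demonstration of the new accessors.
void DemoContextUsage() {
  auto ctx = mindspore::MsContext::GetInstance();
  ctx->set_param<bool>(mindspore::MS_CTX_ENABLE_HCCL, true);
  if (ctx->get_param<bool>(mindspore::MS_CTX_ENABLE_HCCL)) {
    // Reference-count style parameters use increase/decrease instead of set.
    ctx->increase_param<uint32_t>(mindspore::MS_CTX_TSD_REF);
  }
}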
View File

@ -42,7 +42,7 @@ class TestOptLib : public UT::Common {
parse::data_converter::ClearObjectCache();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
}
FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) {
equiv_node.clear();

View File

@ -112,7 +112,7 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
// Do insert_trans_op_ pass of hardware opt
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();

View File

@ -112,7 +112,7 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
// insert trans op for output
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();

View File

@ -104,7 +104,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);

View File

@ -83,7 +83,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);

View File

@ -76,7 +76,7 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
*/
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
// Renormalize func_graph to infer and set shape and type information.
std::vector<int> shp{2, 32, 224, 224};

View File

@ -71,15 +71,15 @@ OpExecInfoPtr ConstructOpExecInfo() {
TEST_F(TestPynativeExecute, TestCreateContext) {
auto ctx3 = MsContext::GetInstance();
ASSERT_EQ(ctx3->backend_policy(), "vm");
ASSERT_EQ(ctx3->device_target(), "CPU");
ASSERT_EQ(ctx3->get_param<std::string>(MS_CTX_DEVICE_TARGET), "CPU");
ctx3->set_backend_policy("ge_only");
ctx3->set_device_target("GPU");
ctx3->set_param<std::string>(MS_CTX_DEVICE_TARGET, "GPU");
auto ctx4 = MsContext::GetInstance();
ASSERT_EQ(ctx3.get(), ctx4.get());
ASSERT_EQ(ctx4->backend_policy(), "ge_only");
ASSERT_EQ(ctx4->device_target(), "GPU");
ASSERT_EQ(ctx4->get_param<std::string>(MS_CTX_DEVICE_TARGET), "GPU");
}
TEST_F(TestPynativeExecute, TestDefaultContext) {