From fcbb3e0edcdd5c4bb48e65fdf351c08643537a41 Mon Sep 17 00:00:00 2001 From: fary86 Date: Thu, 27 Aug 2020 17:24:08 +0800 Subject: [PATCH] Refactor ms_context implementation --- .../hccl/hcom_all_broadcast.cc | 2 +- .../kernel_compiler/hccl/hcom_all_gather.cc | 2 +- .../kernel_compiler/hccl/hcom_all_reduce.cc | 2 +- .../hccl/hcom_all_reduce_scatter.cc | 2 +- .../backend/kernel_compiler/kernel_query.cc | 3 +- .../backend/kernel_compiler/oplib/oplib.cc | 2 +- .../ascend/ascend_backend_optimization.cc | 46 +-- .../ascend/format_type/insert_trans_op.cc | 3 +- .../rectify_do_mask_kernel_info.cc | 2 +- .../common/common_backend_optimization.cc | 4 +- .../ccsrc/backend/optimizer/common/helper.cc | 3 +- .../backend/optimizer/common/pass_manager.cc | 4 +- .../pass/const_to_attr_strided_slice_grad.cc | 2 +- .../backend/session/ascend_control_parser.cc | 4 +- .../ccsrc/backend/session/ascend_session.cc | 28 +- .../ccsrc/backend/session/gpu_session.cc | 8 +- .../ccsrc/backend/session/infer_session.cc | 8 +- .../ccsrc/backend/session/session_basic.cc | 11 +- .../ccsrc/backend/session/session_basic.h | 2 +- mindspore/ccsrc/debug/data_dump_parser.cc | 2 +- mindspore/ccsrc/debug/debugger/debugger.cc | 2 +- mindspore/ccsrc/debug/dump_proto.cc | 2 +- mindspore/ccsrc/debug/e2e_dump.cc | 4 +- mindspore/ccsrc/frontend/optimizer/ad/grad.cc | 2 +- .../ccsrc/frontend/optimizer/ad/kprim.cc | 2 +- .../optimizer/irpass/arithmetic_simplify.cc | 6 +- .../ccsrc/frontend/optimizer/optimizer.h | 2 +- mindspore/ccsrc/frontend/optimizer/py_pass.cc | 4 +- .../parallel/graph_util/graph_info.cc | 2 +- .../frontend/parallel/step_auto_parallel.cc | 2 +- mindspore/ccsrc/pipeline/jit/action.cc | 20 +- mindspore/ccsrc/pipeline/jit/base.h | 2 +- mindspore/ccsrc/pipeline/jit/init.cc | 133 +++++--- mindspore/ccsrc/pipeline/jit/pass.cc | 2 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 19 +- mindspore/ccsrc/pipeline/jit/pipeline_ge.cc | 4 +- .../pipeline/jit/static_analysis/evaluator.cc | 7 +- 
.../pipeline/jit/static_analysis/evaluator.h | 2 +- .../pipeline/jit/static_analysis/prim.cc | 2 +- .../pipeline/pynative/pynative_execute.cc | 10 +- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 2 +- .../device/ascend/ascend_device_address.cc | 12 +- .../device/ascend/ascend_kernel_runtime.cc | 14 +- .../device/ascend/ascend_memory_manager.cc | 2 +- .../device/ascend/ascend_stream_assign.cc | 2 +- .../runtime/device/ascend/dump/data_dumper.cc | 2 +- .../device/ascend/kernel_select_ascend.cc | 2 +- .../ascend/profiling/profiling_manager.cc | 2 +- .../ascend/profiling/profiling_manager.h | 2 +- .../ascend/profiling/profiling_utils.cc | 6 +- .../runtime/device/gpu/gpu_kernel_runtime.cc | 6 +- .../device/gpu/gpu_memory_allocator.cc | 4 +- .../runtime/device/gpu/gpu_memory_manager.cc | 4 +- .../runtime/device/gpu/kernel_info_setter.cc | 2 +- .../ccsrc/runtime/device/kernel_adjust.cc | 4 +- .../ccsrc/runtime/device/kernel_runtime.cc | 12 +- .../ccsrc/runtime/device/memory_manager.cc | 2 +- mindspore/ccsrc/transform/graph_ir/convert.cc | 2 +- .../ccsrc/utils/context/context_extends.cc | 85 ++--- mindspore/ccsrc/utils/convert_utils_py.cc | 2 +- mindspore/ccsrc/utils/tensorprint_utils.cc | 2 +- mindspore/ccsrc/vm/backend.cc | 4 +- mindspore/ccsrc/vm/transform.cc | 18 +- mindspore/context.py | 127 ++++---- mindspore/core/abstract/prim_others.cc | 2 +- mindspore/core/ir/anf.cc | 4 +- mindspore/core/ir/func_graph_cloner.cc | 2 +- mindspore/core/ir/meta_func_graph.cc | 2 +- mindspore/core/utils/ms_context.cc | 123 +++----- mindspore/core/utils/ms_context.h | 296 ++++++++++-------- tests/ut/cpp/optimizer/lib_test.cc | 2 +- .../format_type/insert_trans_op_test.cc | 2 +- .../remove_internal_output_test.cc | 2 +- .../ascend/ir_fission/transdata_split_test.cc | 2 +- .../transpose_transdata_fusion_test.cc | 2 +- .../pass/eliminate_redundant_op_test.cc | 2 +- .../ut/cpp/pynative/pynative_execute_test.cc | 6 +- 77 files changed, 582 insertions(+), 554 deletions(-) diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc index b0be4c4ef1..d031fed921 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc @@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, const std::vector & /*outputs*/, void *stream_ptr) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { + if (context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK)) { return true; } if (inputs.empty() || hccl_data_type_list_.empty()) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc index 088ca37db9..fc4832ae6f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc @@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector &inputs, const st const std::vector &outputs, void *stream_ptr) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { + if (context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK)) { return true; } if (inputs.empty() || hccl_data_type_list_.empty()) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc index 70bf880f64..4d8cd690f0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc @@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector &inputs, const st const std::vector &outputs, void *stream_ptr) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { + if 
(context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK)) { return true; } if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc index cbeab9e7e1..08281c6030 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc @@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector &inputs, const std::vector &outputs, void *stream_ptr) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { + if (context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK)) { return true; } if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc index d3555bad07..4e32121cbb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc @@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorenable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { + if (context_ptr->get_param(MS_CTX_ENABLE_GRAPH_KERNEL) && + IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { kernel_type = KernelType::AKG_KERNEL; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc index 4442acc314..69209ed985 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc @@ -328,7 +328,7 @@ std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType im } auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool is_gpu = (context->device_target() == 
kGPUDevice); + bool is_gpu = (context->get_param(MS_CTX_DEVICE_TARGET) == kGPUDevice); if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) << ", current op num: " << op_info_.size(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 9401696fa4..3c4074c89e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr(); auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); - if (context_ptr->execution_mode() == kPynativeMode) { + if (context_ptr->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); } else { @@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrenable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + if (context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param(MS_CTX_ENABLE_LOOP_SINK) && + ConfigManager::GetInstance().iter_num() > 1) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); 
ir_fusion_pm->AddPass(std::make_shared()); @@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->ir_fusion_flag()) { + if (!context_ptr->get_param(MS_CTX_IR_FUSION_FLAG)) { MS_LOG(INFO) << "IRFusion is not enable, skip"; return; } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr &kern auto other2_pm = std::make_shared("other2_pm"); other2_pm->AddPass(std::make_shared()); other2_pm->AddPass(std::make_shared()); - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + if (context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param(MS_CTX_ENABLE_LOOP_SINK) && + ConfigManager::GetInstance().iter_num() > 1) { other2_pm->AddPass(std::make_shared()); } other2_pm->AddPass(std::make_shared()); @@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr &ke bool is_before_kernel_select) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { 
+ if (!(context_ptr->get_param(MS_CTX_ENABLE_GRAPH_KERNEL))) { return; } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr &kern bool is_before_kernel_select) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { + if (!(context_ptr->get_param(MS_CTX_ENABLE_GRAPH_KERNEL))) { return; } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr &kern void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { + if (!(context_ptr->get_param(MS_CTX_ENABLE_GRAPH_KERNEL))) { return; } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr &ke void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->ir_fusion_flag()) { + if 
(!context_ptr->get_param(MS_CTX_IR_FUSION_FLAG)) { MS_LOG(INFO) << "UBFusion is not enable, skip"; return; } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc index 88afc5444e..8df1909b55 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode && + !ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_HOOK)) { if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { return new_node; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc index cff8ce5673..b7544a37ec 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con auto cnode = node->cast(); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode) { + if 
(ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { return RectifyKernelInfoInPynativeProcess(node); } if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc index 236535ef3b..e8f51b6b30 100644 --- a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc @@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr &kern MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index ffe9be054d..f9770fc949 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { bool IsNopNode(const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { + if (context_ptr->get_param(MS_CTX_DEVICE_TARGET) != kAscendDevice && + context_ptr->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { return false; } static std::unordered_set nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc 
b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc index fc83d3e389..e0ad2b53fa 100644 --- a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc @@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector } auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc index 7d4b54b23a..82d1d981ad 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc @@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->device_target() == kAscendDevice) { + if (ms_context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { if (!CheckAttrs(strided_slice_grad)) { MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; return nullptr; diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 9430783686..367e7d2ec6 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull root_graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - auto save_graphs_path = 
context_ptr->save_graphs_path(); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } - if (context_ptr->save_graphs_flag()) { + if (context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir"; DumpIR(file_path, root_graph.get()); } diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index ee7363b060..29a0795afb 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { debugger_->PreExecute(graph); } #endif - if (ms_context->precompile_only()) { + if (ms_context->get_param(MS_CTX_PRECOMPILE_ONLY)) { MS_LOG(INFO) << "Precompile only, stop in build kernel step"; } else { // alloc memory, including static memory and dynamic memory @@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { child_graph->SetExecOrderByDefault(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kGraphMode) { + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { if (raise_precision_count > 0) { MS_LOG(WARNING) << "There has " << raise_precision_count << " node/nodes used raise precision to selected the kernel!"; @@ -481,8 +481,8 @@ void 
AscendSession::AdjustKernel(const std::shared_ptr &kernel_grap device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector &all_graphs) #ifdef ENABLE_DUMP_IR auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); if (!save_graphs) { return; } - auto save_graphs_path = context_ptr->save_graphs_path(); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() { if (graph_order.size() > 1) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->enable_task_sink()) { + if (!context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK)) { MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!"; } } @@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull graph, NotNullsave_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs) { if (save_graphs_path.empty()) { save_graphs_path = "."; @@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull root_graph) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kGraphMode) 
{ + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { if (raise_precision_count > 0) { MS_LOG(WARNING) << "There are " << raise_precision_count << " node/nodes used raise precision to selected the kernel!"; @@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull graph, auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs) { if (save_graphs_path.empty()) { save_graphs_path = "."; diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index bf2165015d..85d042746d 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) { + if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); auto tensor_address = std::dynamic_pointer_cast(tensor->device_address()); bool need_sync = false; - if (ms_context->enable_pynative_infer()) { + if (ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER)) { if (tensor_address == nullptr || tensor_address != device_address) { need_sync = true; } @@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const 
AnfNodePtrList // Prepare ms context info for dump .pb graph auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); // Optimize Optimize(graph); // Select kernel build info @@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vectorenable_gpu_summary()) { + if (context_ptr->get_param(MS_CTX_ENABLE_GPU_SUMMARY)) { Summary(kernel_graph.get()); } #ifdef ENABLE_DEBUGGER diff --git a/mindspore/ccsrc/backend/session/infer_session.cc b/mindspore/ccsrc/backend/session/infer_session.cc index e9985d8a4d..068a378f75 100644 --- a/mindspore/ccsrc/backend/session/infer_session.cc +++ b/mindspore/ccsrc/backend/session/infer_session.cc @@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() { return; } Initialized = true; - MsContext::GetInstance()->set_execution_mode(kGraphMode); + MsContext::GetInstance()->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); Py_Initialize(); auto c_expression = PyImport_ImportModule("mindspore._c_expression"); if (c_expression == nullptr) { @@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) { MS_LOG(ERROR) << "Get Context failed!"; return FAILED; } - ms_context->set_execution_mode(kGraphMode); - ms_context->set_device_id(device_id); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); + ms_context->set_param(MS_CTX_DEVICE_ID, device_id); auto ajust_device = AjustTargetName(device); if (ajust_device == "") { return FAILED; } - ms_context->set_device_target(device); + ms_context->set_param(MS_CTX_DEVICE_TARGET, device); if (!context::OpenTsd(ms_context)) { MS_LOG(ERROR) << "Session init OpenTsd failed!"; return FAILED; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 8a195f32db..865f7adc5d 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ 
b/mindspore/ccsrc/backend/session/session_basic.cc @@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o // if in paynative mode,data only copyed to host when user want to print data auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) { + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && + ms_context->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { tensor->set_need_sync(true); } - if (ms_context->execution_mode() != kPynativeMode) { + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { tensor->SetNeedWait(true); } tensor->set_dirty(false); @@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0); - if (ms_context->enable_pynative_infer()) { + if (ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER)) { return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; } if (tensor->is_dirty()) { @@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); - if (ms_context->execution_mode() == kPynativeMode || + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode || AnfAlgo::IsParameterWeight(input_node->cast())) { tensor->set_device_address(device_address); } @@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: if (backend_anf != nullptr) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->execution_mode() == 
kPynativeMode) { + if (context_ptr->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { return backend_anf; } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 1e963722be..863a222034 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this { debugger_ = Debugger::GetInstance(); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - debugger_->Init(device_id_, ms_context->device_target()); + debugger_->Init(device_id_, ms_context->get_param(MS_CTX_DEVICE_TARGET)); } #endif diff --git a/mindspore/ccsrc/debug/data_dump_parser.cc b/mindspore/ccsrc/debug/data_dump_parser.cc index 4666ba20ba..58876ecaab 100644 --- a/mindspore/ccsrc/debug/data_dump_parser.cc +++ b/mindspore/ccsrc/debug/data_dump_parser.cc @@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - if (context->execution_mode() == kPynativeMode) { + if (context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump"; } return true; diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc index d57537834f..3b42039214 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.cc +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -142,7 +142,7 @@ void Debugger::EnableDebugger() { // switch memory reuse on or off auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - context_ptr->set_enable_mem_reuse(partial_memory_); + context_ptr->set_param(MS_CTX_ENABLE_MEM_REUSE, partial_memory_); // print some message about memory reuse to user if (partial_memory_) { MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. 
Please only set watchpoints before running the first " diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc index cd8a2a982d..0c8d912118 100644 --- a/mindspore/ccsrc/debug/dump_proto.cc +++ b/mindspore/ccsrc/debug/dump_proto.cc @@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { MS_LOG(ERROR) << "ms_context is nullptr"; return; } - auto save_graphs_path = ms_context->save_graphs_path(); + auto save_graphs_path = ms_context->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc index 147582a360..cbf914e7fd 100644 --- a/mindspore/ccsrc/debug/e2e_dump.cc +++ b/mindspore/ccsrc/debug/e2e_dump.cc @@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); // dump_enable_ is true, close mem reuse - context_ptr->set_enable_mem_reuse(!dump_enable_); + context_ptr->set_param(MS_CTX_ENABLE_MEM_REUSE, !dump_enable_); trans_flag_ = trans_flag; dump_mode_ = mode; dump_path_ = path; @@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() { } auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - auto id = context_ptr->device_id(); + auto id = context_ptr->get_param(MS_CTX_DEVICE_ID); char real_path[PATH_MAX] = {0}; if (nullptr == realpath(config_path_str, real_path)) { MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str; diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc index bc08bd94bf..c41efe4930 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt manager_ptr->AddFuncGraph(func_graph); auto multi_graph_sink = 
[&func_graph](const FuncGraphPtr &f) { - if (MsContext::GetInstance()->is_multi_graph_sink()) { + if (MsContext::GetInstance()->get_param(MS_CTX_IS_MULTI_GRAPH_SINK)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 3b296b5880..4b3c1fd745 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool check_bprop_flag = context->check_bprop_flag(); + bool check_bprop_flag = context->get_param(MS_CTX_CHECK_BPROP_FLAG); // Skip checking if check_bprop not set if (!check_bprop_flag) { return; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index b69fb20bee..a306919d2f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -29,7 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr PConstant const_2(node); PConstant any_const(node); - if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { MATCH_REPLACE(node, x + zero_, x); // Add by zero MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero @@ -41,7 +41,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr } // Prim Eliminate (identity) MATCH_REPLACE(node, 
PPrimitive(prim::kPrimIdentity, x), x); - if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { return nullptr; } @@ -75,7 +75,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr } AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { return nullptr; } PatternNode x, y; diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer.h b/mindspore/ccsrc/frontend/optimizer/optimizer.h index bc60181587..1a3d31654f 100644 --- a/mindspore/ccsrc/frontend/optimizer/optimizer.h +++ b/mindspore/ccsrc/frontend/optimizer/optimizer.h @@ -181,7 +181,7 @@ class Optimizer : public std::enable_shared_from_this { } }; use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); - if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { + if (is_on_debug_ && MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; auto fg_name = "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index 6473b9825a..95a90b44d6 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -217,8 +217,8 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph void DrawNode(string name, AnfNodePtr node) { auto context_ptr = MsContext::GetInstance(); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); + bool save_graphs = 
context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + auto save_graphs_path = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc index 6b360e4a97..c6d8aeb22a 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc @@ -44,7 +44,7 @@ std::vector FindPrimtive(const FuncGraphPtr &graph, const std::str } void DumpGraph(const FuncGraphPtr &root, const std::string &name) { - if (MsContext::GetInstance()->save_graphs_flag()) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { draw::Draw(name + ".dot", root); DumpIR(name + ".ir", root); ExportIR(name + ".dat", "0", root); diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 9705862a44..8944fa011e 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -69,7 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { struct timeval start_time, end_time; (void)gettimeofday(&start_time, nullptr); - if (MsContext::GetInstance()->save_graphs_flag()) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root); } MS_LOG(INFO) << "Now entering step auto parallel"; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 109c27af6d..7470367f62 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -271,7 +271,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) if (!result) { MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; } - if (MsContext::GetInstance()->save_graphs_flag() && 
res->func_graph() != nullptr) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) { auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; auto func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -295,20 +295,20 @@ bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, static bool IsCtrlSink() { auto ms_ctx = MsContext::GetInstance(); - if (ms_ctx->execution_mode() != kGraphMode) { + if (ms_ctx->get_param(MS_CTX_EXECUTION_MODE) != kGraphMode) { return false; } - std::string device_target = ms_ctx->device_target(); + std::string device_target = ms_ctx->get_param(MS_CTX_DEVICE_TARGET); if (device_target != kAscendDevice) { return false; } - if (!ms_ctx->enable_task_sink()) { + if (!ms_ctx->get_param(MS_CTX_ENABLE_TASK_SINK)) { return false; } - if (!ms_ctx->is_multi_graph_sink()) { + if (!ms_ctx->get_param(MS_CTX_IS_MULTI_GRAPH_SINK)) { return false; } return true; @@ -325,13 +325,13 @@ bool TaskEmitAction(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(context_ptr); if (CompileGraphs::ContainMixedTarget(func_graph)) { bc_ptr->set_is_multi_graph_sink(false); - context_ptr->set_is_multi_graph_sink(false); - context_ptr->set_loop_sink_flag(false); - } else if (context_ptr->execution_mode() != kPynativeMode) { - std::string device_target = context_ptr->device_target(); + context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, false); + context_ptr->set_param(MS_CTX_ENABLE_LOOP_SINK, false); + } else if (context_ptr->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { + std::string device_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); if (device_target == kAscendDevice && backend != kMsVm) { bc_ptr->set_is_multi_graph_sink(true); - context_ptr->set_is_multi_graph_sink(true); + context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, true); } } diff --git a/mindspore/ccsrc/pipeline/jit/base.h b/mindspore/ccsrc/pipeline/jit/base.h index 595c2decd4..38cdcd19a6 100644 
--- a/mindspore/ccsrc/pipeline/jit/base.h +++ b/mindspore/ccsrc/pipeline/jit/base.h @@ -49,7 +49,7 @@ inline std::string GetFilePathName(const std::string &file_name) { if (ms_context == nullptr) { MS_LOG(EXCEPTION) << "ms_context is nullptr"; } - auto save_graphs_path = ms_context->save_graphs_path(); + auto save_graphs_path = ms_context->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index d89dc44448..2ee67882f4 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -48,6 +48,56 @@ using OpLib = mindspore::kernel::OpLib; using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; using ParallelContext = mindspore::parallel::ParallelContext; using CostModelContext = mindspore::parallel::CostModelContext; +using mindspore::MsCtxParam; + +namespace mindspore { +void MsCtxSetParameter(std::shared_ptr ctx, MsCtxParam param, const py::object &value) { + MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '" + << py::str(value.get_type()) << "'."; + if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + + MS_LOG(EXCEPTION) << "Got illegal param " << param << " and 
value with type " << py::str(value.get_type()); +} + +py::object MsCtxGetParameter(const std::shared_ptr &ctx, MsCtxParam param) { + if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) { + return py::bool_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) { + return py::int_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) { + return py::int_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) { + return py::float_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) { + return py::str(ctx->get_param(param)); + } + + MS_LOG(EXCEPTION) << "Got illegal param " << param << "."; +} +} // namespace mindspore // Interface with python PYBIND11_MODULE(_c_expression, m) { @@ -101,53 +151,48 @@ PYBIND11_MODULE(_c_expression, m) { (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); + (void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter."); + (void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified parameter."); + + (void)py::enum_(*m, "ms_ctx_param", py::arithmetic()) + .value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG) + .value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) + .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) + .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) + .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) + .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) + .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) + .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK) + .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) + .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) + 
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) + .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) + .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) + .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) + .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) + .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) + .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) + .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) + .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) + .value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) + .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) + .value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE) + .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) + .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) + .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) + .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) + .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) + .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) + .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) + .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) + .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) + .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) + .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); + (void)py::class_>(m, "MSContext") .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") - .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.") - .def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.") - .def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.") - 
.def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.") - .def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.") - .def("get_device_target", &mindspore::MsContext::device_target, "Get device target.") - .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") - .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") - .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") - .def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.") - .def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.") - .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") - .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") - .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, - "Get whether to enable auto mixed precision.") - .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, - "Set whether to enable auto mixed precision.") - .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, - "Get whether to enable reduce precision.") - .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, - "Set whether to enable reduce precision.") - .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") - .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") - .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.") - .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.") - .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.") - .def("set_save_dump_path", 
&mindspore::MsContext::set_save_dump_path, "Set path to dump.") - .def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.") - .def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size, - "set variable memory max size") - .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") - .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") - .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") - .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.") - .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") - .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") - .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") - .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") - .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") - .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, - "Set the GraphKernel switch to on or off.") - .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") - .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.") - .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity."); + .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy."); (void)py::class_>(m, "MpiConfig") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc 
b/mindspore/ccsrc/pipeline/jit/pass.cc index 538a8894f5..0172adb793 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -271,7 +271,7 @@ void InitOpt(const ResourcePtr &res) { g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { + if (!(context_ptr->get_param(MS_CTX_ENABLE_GRAPH_KERNEL))) { g_pass_opts["opt_graph_kernel_a"]->set_enable(false); g_pass_opts["opt_graph_kernel_b"]->set_enable(false); } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 76244c2186..c52855e6b7 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -88,7 +88,7 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { if (ms_context == nullptr) { MS_LOG(EXCEPTION) << "ms_context is nullptr"; } - auto save_graphs_path = ms_context->save_graphs_path(); + auto save_graphs_path = ms_context->get_param(MS_CTX_SAVE_GRAPHS_PATH); if (save_graphs_path.empty()) { save_graphs_path = "."; } @@ -646,7 +646,7 @@ void Pipeline::Run() { if (!result) { MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; } - if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) { auto graph = resource_->func_graph(); if (graph != nullptr) { user_graph = graph; @@ -688,7 +688,7 @@ void Pipeline::Run() { MsProfile::Reset(); #endif - if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) { std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); MS_LOG(DEBUG) << "Save user graph to: " << 
user_graph_file; draw::DrawUserFuncGraph(user_graph_file, user_graph); @@ -710,7 +710,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef if (!succ) { MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; } - if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa()) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == 0 && !converted->isa()) { MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() << " is not tensor."; } @@ -891,7 +891,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc // Convert CNodeList to LinConvertResult. ConfigManager::GetInstance().set_iter_num(1); auto runner = convert_fn({app_init}, ""); - if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { backend->Link(runner.graph_id); } ConfigManager::GetInstance().set_iter_num(size); @@ -965,10 +965,11 @@ void InitHccl() { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); (void)context::OpenTsd(ms_context); - uint32_t device_id = ms_context->device_id(); - std::string device_name = ms_context->device_target(); - ms_context->set_enable_hccl(true); - if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { + uint32_t device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + std::string device_name = ms_context->get_param(MS_CTX_DEVICE_TARGET); + ms_context->set_param(MS_CTX_ENABLE_HCCL, true); + if (ms_context->backend_policy() == "ms" && + ms_context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice) { auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); MS_EXCEPTION_IF_NULL(runtime_instance); if (!runtime_instance->Init()) { diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc index 
9c1b7f714d..23b46c604c 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc @@ -214,7 +214,7 @@ bool AddDFGraph(const std::map &info, const py::di return false; } - if (MsContext::GetInstance()->save_graphs_flag()) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug @@ -244,7 +244,7 @@ FuncGraphPtr BuildDFGraph(const std::map &info, co } FuncGraphPtr anf_graph = info.at(phase)->func_graph; - if (MsContext::GetInstance()->save_graphs_flag()) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index be45af748a..2584d16958 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -118,8 +118,9 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << ", current function call depth: " << engine->function_call_depth(); AbstractBasePtr ret_base = nullptr; engine->IncreaseFunctionCallDepth(); - if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) { - MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << "."; + if (engine->function_call_depth() > MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH)) { + MS_LOG(EXCEPTION) << "Exceed function call depth limit " + << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) << "."; } std::vector nodes = FastShadowSort(func_node); for (auto it = 
nodes.crbegin(); it != nodes.crend(); it++) { @@ -409,7 +410,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg bparams.push_back(SensitivityTransform(orig_func_)); auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); + bool enable_sparse = context->get_param(MS_CTX_ENABLE_SPARSE); (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { if (enable_sparse && arg_spec->isa()) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h index af597d1d33..15e97caa9c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h @@ -62,7 +62,7 @@ class Evaluator : public Base { virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); + bool enable_sparse = context->get_param(MS_CTX_ENABLE_SPARSE); if (!enable_sparse) { return nullptr; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 5490221531..f66ab77101 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -290,7 +290,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { if (abs_base->isa()) { auto arg_tensor = dyn_cast(abs_base); dic["shape"] = arg_tensor->shape()->shape(); - if (MsContext::GetInstance()->execution_mode() == kGraphMode) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { const auto &min_shape = arg_tensor->shape()->min_shape(); const auto &max_shape = arg_tensor->shape()->max_shape(); if (!min_shape.empty() && 
!max_shape.empty()) { diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 9d82d9a03d..a4a4870927 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -558,8 +558,8 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat MS_EXCEPTION_IF_NULL(op_exec_info); MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; auto ms_context = MsContext::GetInstance(); - ms_context->set_enable_pynative_infer(true); - std::string device_target = ms_context->device_target(); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, true); + std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); if (device_target != kAscendDevice && device_target != kGPUDevice) { MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; } @@ -567,7 +567,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat if (session == nullptr) { session = session::SessionFactory::Get().Create(device_target); MS_EXCEPTION_IF_NULL(session); - session->Init(ms_context->device_id()); + session->Init(ms_context->get_param(MS_CTX_DEVICE_ID)); } std::vector input_tensors; @@ -578,7 +578,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); EraseValueNodeTensor(tensors_mask, &input_tensors); py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors); - ms_context->set_enable_pynative_infer(false); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); *status = PYNATIVE_SUCCESS; MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; return result; @@ -1308,7 +1308,7 @@ void PynativeExecutor::Clear(const std::string &flag) { // Maybe exit in 
the pynative runing op, so need reset pynative flag. auto ms_context = MsContext::GetInstance(); if (ms_context != nullptr) { - ms_context->set_enable_pynative_infer(false); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); } ConfigManager::GetInstance().ResetIterNum(); return; diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 2ca20c9f5a..f26093ef03 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -89,7 +89,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args) MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size() << " is not equal to the args number: " << py_args.size() - 2 << "."; } - if (MsContext::GetInstance()->check_bprop_flag()) { + if (MsContext::GetInstance()->get_param(MS_CTX_CHECK_BPROP_FLAG)) { for (size_t i = 0; i < grads.size(); i++) { if (py::isinstance(py_args[i])) { if (!py::isinstance(grads[i])) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 90f15b9e47..3be2851d4b 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -154,7 +154,7 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - auto device_id = ms_context->device_id(); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); MS_EXCEPTION_IF_NULL(runtime_instance); auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type); @@ -261,11 +261,12 @@ void 
AscendDeviceAddress::SyncStream() const { MS_LOG(INFO) << "Start!"; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() != kPynativeMode && !ms_context->enable_pynative_infer()) { + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && + !ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER)) { MS_LOG(INFO) << "Finish!"; return; } - auto device_id = ms_context->device_id(); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); MS_EXCEPTION_IF_NULL(runtime_instance); auto ret = runtime_instance->SyncStream(); @@ -348,7 +349,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - auto device_id = ms_context->device_id(); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); MS_EXCEPTION_IF_NULL(runtime_instance); auto ret = @@ -475,7 +476,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh std::vector device_shape = GetDeviceShape(&host_shape); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() != kGraphMode && ms_context->execution_mode() != kPynativeMode && + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kGraphMode && + ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && type_id_name_map.find(type_id_) != type_id_name_map.end()) { std::pair type_format = std::make_pair(type_id_name_map.at(type_id_), format_); if (use_trans_data.find(type_format) != use_trans_data.end()) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 7f2bde2e08..d23c1774c9 
100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -158,7 +158,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std bool AscendKernelRuntime::NeedDestroyHccl() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->enable_hccl()) { + if (!context_ptr->get_param(MS_CTX_ENABLE_HCCL)) { MS_LOG(INFO) << "Hccl is not enabled"; return false; } @@ -177,7 +177,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - auto ret = rtSetDevice(context_ptr->device_id()); + auto ret = rtSetDevice(context_ptr->get_param(MS_CTX_DEVICE_ID)); if (ret != RT_ERROR_NONE) { MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast(ret) << "]"; } @@ -461,12 +461,12 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->enable_task_sink(); + bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); if (!is_task_sink) { return true; } #ifdef MEM_REUSE_DEBUG - if (!context_ptr->enable_mem_reuse()) { + if (!context_ptr->get_param(MS_CTX_ENABLE_MEM_REUSE)) { // Get normal graph ir for memreuse mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); } @@ -518,7 +518,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { MS_LOG(INFO) << "LoadTask start. 
GraphId:" << graph->graph_id(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->enable_task_sink(); + bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); if (!is_task_sink) { return true; } @@ -658,7 +658,7 @@ bool AscendKernelRuntime::InitDevice() { MS_LOG(ERROR) << "Get MsContext instance failed"; return false; } - if (context_ptr->enable_hccl()) { + if (context_ptr->get_param(MS_CTX_ENABLE_HCCL)) { if (!HcclInit()) { MS_LOG(ERROR) << "HcclInit init failed"; return false; @@ -746,7 +746,7 @@ bool AscendKernelRuntime::DestroyHccl() { return false; } MS_LOG(INFO) << "Hccl destroy successful, status = " << res << "."; - context_ptr->set_enable_hccl(false); + context_ptr->set_param(MS_CTX_ENABLE_HCCL, false); return true; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index a4278f4b03..93e38278f1 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -43,7 +43,7 @@ void AscendMemoryManager::MallocDeviceMemory() { uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - auto variable_memory_max_size = context->variable_memory_max_size(); + auto variable_memory_max_size = context->get_param(MS_CTX_VARIABLE_MEMORY_MAX_SIZE); if (variable_memory_max_size == "0") { return 0; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 6007ebc84e..65f11e0bc7 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -1373,7 +1373,7 @@ vector::iterator AscendStreamAssign::FindTargetOp(vector::it bool AscendStreamAssign::IsTaskSink() { auto ms_context = 
MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->enable_task_sink()) { + if (!ms_context->get_param(MS_CTX_ENABLE_TASK_SINK)) { MS_LOG(INFO) << "Task sink mode is not enable"; return false; } else { diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc index 65216a83c7..c40322e3bb 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -117,7 +117,7 @@ void DataDumper::SetOpMappingInfo(NotNull dump_inf if (!dump_path.has_value()) { MS_LOG(EXCEPTION) << "Dump path invalid"; } - auto device_id = context_ptr->device_id(); + auto device_id = context_ptr->get_param(MS_CTX_DEVICE_ID); dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/"); MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value(); diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index a96a3011b6..9189741add 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -363,7 +363,7 @@ void PrecisionReduce(const std::vector &node_mix_precision_datatype_index, *precision_reduce = false; return; } - if (context_ptr->enable_reduce_precision()) { + if (context_ptr->get_param(MS_CTX_ENABLE_REDUCE_PRECISION)) { selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, &kernel_match_datatype_idx_copy); } diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc index 18d67d5253..b8e15c1886 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc 
@@ -117,7 +117,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { } auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - const string prof_options_str = context->profiling_options(); + const string prof_options_str = context->get_param(MS_CTX_PROFILING_OPTIONS); std::vector opts = Split(prof_options_str, ':'); if (opts.empty()) { MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!"; diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h index 26118bb27c..60f5e1a9dd 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h @@ -41,7 +41,7 @@ class ProfilingManager { inline bool IsProfiling() const { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - return context->enable_profiling(); + return context->get_param(MS_CTX_ENABLE_PROFILING); } protected: diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc index 8bc5f258b7..e93761fcb9 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc @@ -342,12 +342,12 @@ void ProfilingUtils::ReportProfilingData(const std::vector &task_ids, auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second); + TaskDescReporter task_reporter(context->get_param(MS_CTX_DEVICE_ID), "vm.task_desc_info", ret->second); task_reporter.set_task_ids(task_ids); task_reporter.set_stream_ids(stream_ids); task_reporter.ReportData(); - GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second); + GraphDescReporter graph_reporter(context->get_param(MS_CTX_DEVICE_ID), 
"vm.graph_desc_info", ret->second); graph_profiling_cnode_.erase(ret); graph_reporter.ReportData(); @@ -357,7 +357,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector &task_ids, MS_LOG(ERROR) << "Graph id not found in graph_point"; return; } - PointReporter point_reporter(context->device_id(), "vm.point"); + PointReporter point_reporter(context->get_param(MS_CTX_DEVICE_ID), "vm.point"); for (const auto &point : point_iter->second) { point_reporter.AddReportData(point); } diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 5602b86d1f..4fae7d9058 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -416,7 +416,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { mem_manager_->ResetDynamicMemory(); AssignStaticMemoryInput(graph); AssignStaticMemoryValueNode(graph); - bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); + bool is_enable_dynamic_mem = context_ptr->get_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL); if (is_enable_dynamic_mem) { // Use the dynamic memory pool. 
InitKernelRefCount(graph); @@ -435,8 +435,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { bool ret = true; auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); - bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); + bool is_enable_dynamic_mem = context_ptr->get_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL); + bool is_enable_pynative_infer = context_ptr->get_param(MS_CTX_ENABLE_PYNATIVE_INFER); if (is_enable_dynamic_mem && !is_enable_pynative_infer) { auto graph_id = graph->graph_id(); auto iter = mem_swap_map_.find(graph_id); diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc index 746aeda2cd..2a8bc5cc0d 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc @@ -29,7 +29,7 @@ bool GPUMemoryAllocator::Init() { size_t free_size = CudaDriver::free_mem_size(); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - limited_device_memory_ = context_ptr->max_device_memory(); + limited_device_memory_ = context_ptr->get_param(MS_CTX_MAX_DEVICE_MEMORY); available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024); if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size @@ -44,7 +44,7 @@ bool GPUMemoryAllocator::Init() { void GPUMemoryAllocator::CheckMaxDeviceMemory() const { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - auto max_device_memory = context_ptr->max_device_memory(); + auto max_device_memory = context_ptr->get_param(MS_CTX_MAX_DEVICE_MEMORY); // Currently not support modifying the max device memory. 
if (limited_device_memory_ != max_device_memory) { MS_LOG(EXCEPTION) diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc index 7702234ccd..eab93dd3e8 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc @@ -37,7 +37,7 @@ void GPUMemoryManager::MallocDeviceMemory() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); // If use the dynamic memory pool, then alloc the first memory block to init. - if (context_ptr->enable_dynamic_mem_pool()) { + if (context_ptr->get_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { auto device_addr = MallocMemFromMemPool(1); if (!device_addr) { MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; @@ -65,7 +65,7 @@ void GPUMemoryManager::FreeDeviceMemory() { uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_dynamic_mem_pool()) { + if (context_ptr->get_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { auto device_ptr = MallocMemFromMemPool(size); MS_EXCEPTION_IF_NULL(device_ptr); return AddressOffset(device_ptr, 0); diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index b2801974c8..1b16fafa03 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -162,7 +162,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector &inputs_type) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode) { + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { return false; } if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) 
{ diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc index d0df4f1141..ca3e677398 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -60,8 +60,8 @@ void KernelAdjust::ReorderGetNext(const std::shared_ptr &k bool KernelAdjust::NeedInsertSwitch() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1); + return (context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK) && + context_ptr->get_param(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1); } CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index e89f9bf8e6..1fbfabb10b 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -50,7 +50,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) { struct timeval start_time, end_time; (void)gettimeofday(&start_time, nullptr); #endif - bool is_task_sink = context_ptr->enable_task_sink(); + bool is_task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); if (is_task_sink) { ret = RunTask(graph); } else { @@ -502,7 +502,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode MS_LOG(INFO) << "communication op addr exist"; continue; } - if (context_ptr->enable_hccl()) { + if (context_ptr->get_param(MS_CTX_ENABLE_HCCL)) { mem_size = mem_manager_->GetCommonAlignSize(mem_size); } total_size += mem_size; @@ -646,7 +646,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const DeviceAddressPtr address = nullptr; address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); 
MS_EXCEPTION_IF_NULL(address); - if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) { + if (ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER) && + !mem_manager_->MallocMemFromMemPool(address, node_size)) { MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size; } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size; @@ -682,7 +683,8 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { DeviceAddressPtr address = nullptr; address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); MS_EXCEPTION_IF_NULL(address); - if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { + if (ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_INFER) && + !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size; } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; @@ -701,7 +703,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(mem_manager_); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); + bool is_enable_mem_reuse = context_ptr->get_param(MS_CTX_ENABLE_MEM_REUSE); auto mem_type = kDynamicMem; if (is_enable_mem_reuse) { mem_manager_->MallocReusedDynamicMem(graph); diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc index cd3cd620b5..99ce1d021e 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.cc +++ 
b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -54,7 +54,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me uint8_t *ptr = nullptr; if (AnfAlgo::IsCommunicationOp(node)) { bool communication_mem = false; - if (context_ptr->enable_hccl()) { + if (context_ptr->get_param(MS_CTX_ENABLE_HCCL)) { communication_mem = true; } if (type == kStaticMem) { diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 86f24c92e1..79b857b18f 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -1070,7 +1070,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vectorToString() + "_ge_graph.dot"; - if (MsContext::GetInstance()->save_graphs_flag()) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { convertor.DrawComputeGraph(name); } branches_map_[node.get()] = *(convertor.df_graph_); diff --git a/mindspore/ccsrc/utils/context/context_extends.cc b/mindspore/ccsrc/utils/context/context_extends.cc index 21d1dabbbf..efc1d3d85a 100644 --- a/mindspore/ccsrc/utils/context/context_extends.cc +++ b/mindspore/ccsrc/utils/context/context_extends.cc @@ -41,13 +41,13 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { MS_LOG(EXCEPTION) << "nullptr"; } - if (ms_context_ptr->is_pynative_ge_init()) { + if (ms_context_ptr->get_param(MS_CTX_IS_PYNATIVE_GE_INIT)) { return true; } - if (ms_context_ptr->tsd_ref()) { + if (ms_context_ptr->get_param(MS_CTX_TSD_REF)) { MS_LOG(DEBUG) << "TDT Dataset client is already opened."; - ms_context_ptr->set_tsd_ref("++"); + ms_context_ptr->increase_param(MS_CTX_TSD_REF); return true; } @@ -59,7 +59,7 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { unsigned int device_id; unsigned int rank_size = 1; - device_id = ms_context_ptr->device_id(); + device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); auto rank_size_env = common::GetEnv("RANK_SIZE"); if 
(rank_size_env.empty()) { @@ -79,7 +79,7 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; return false; } - ms_context_ptr->set_tsd_ref("++"); + ms_context_ptr->increase_param(MS_CTX_TSD_REF); #ifdef ENABLE_TDTQUE int32_t initStatus = tdt::TdtHostInit(device_id); if (initStatus != TDT_OK_CODE) { @@ -88,7 +88,8 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { } ms_context_ptr->tdt_print_ = std::thread(TensorPrint()); #endif - MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << ms_context_ptr->tsd_ref() << "."; + MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " + << ms_context_ptr->get_param(MS_CTX_TSD_REF) << "."; return true; } @@ -96,12 +97,12 @@ bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { if (ms_context_ptr == nullptr) { MS_LOG(EXCEPTION) << "nullptr"; } - if (ms_context_ptr->tsd_ref() == 0) { + if (ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { return true; } - ms_context_ptr->set_tsd_ref("--"); - if (force || ms_context_ptr->tsd_ref() == 0) { - ms_context_ptr->set_tsd_ref(" "); + ms_context_ptr->decrease_param(MS_CTX_TSD_REF); + if (force || ms_context_ptr->get_param(MS_CTX_TSD_REF) == 0) { + ms_context_ptr->set_param(MS_CTX_TSD_REF, 0); #ifdef ENABLE_TDTQUE int32_t stopStatus = tdt::TdtHostStop(KNpuLog); if (stopStatus != TDT_OK_CODE) { @@ -123,17 +124,17 @@ bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); } #endif - auto device_id = ms_context_ptr->device_id(); + auto device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); TDT_StatusT status = TsdClose(device_id); if (status != TDT_OK) { MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << "."; return false; } - ms_context_ptr->set_pynative_ge_init(false); + ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); MS_LOG(INFO) << "Destroy 
and close tsd successful, status = " << status << "."; } else { - MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->tsd_ref() - << "."; + MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " + << ms_context_ptr->get_param(MS_CTX_TSD_REF) << "."; } return true; @@ -159,14 +160,14 @@ void GetGeOptions(const std::shared_ptr &ms_context_ptr, std::mapenable_dump()); - (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->save_dump_path(); + (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param(MS_CTX_ENABLE_DUMP)); + (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param(MS_CTX_SAVE_DUMP_PATH); (*ge_options)["ge.exec.dumpMode"] = "output"; - MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->enable_dump()) - << " and save dump path is " << ms_context_ptr->save_dump_path() << "."; - (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->enable_profiling()); - if (ms_context_ptr->enable_profiling()) { - (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->profiling_options(); + MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param(MS_CTX_ENABLE_DUMP)) + << " and save dump path is " << ms_context_ptr->get_param(MS_CTX_SAVE_DUMP_PATH) << "."; + (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param(MS_CTX_ENABLE_PROFILING)); + if (ms_context_ptr->get_param(MS_CTX_ENABLE_PROFILING)) { + (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param(MS_CTX_PROFILING_OPTIONS); } (*ge_options)["rank_table_file"] = ""; @@ -178,12 +179,12 @@ void GetGeOptions(const std::shared_ptr &ms_context_ptr, std::mapgraph_memory_max_size() != "0") { - (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->graph_memory_max_size(); + if (ms_context_ptr->get_param(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") { + (*ge_options)["ge.graphMemoryMaxSize"] = 
ms_context_ptr->get_param(MS_CTX_GRAPH_MEMORY_MAX_SIZE); } - if (ms_context_ptr->variable_memory_max_size() != "0") { - (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->variable_memory_max_size(); + if (ms_context_ptr->get_param(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") { + (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param(MS_CTX_VARIABLE_MEMORY_MAX_SIZE); } #if ENABLE_TRAIN == 1 @@ -224,7 +225,7 @@ void GetGeOptions(const std::shared_ptr &ms_context_ptr, std::mapauto_mixed_precision_flag()) { + if (ms_context_ptr->get_param(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) { (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; } else { (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; @@ -240,7 +241,7 @@ void SetHcclOptions(const std::shared_ptr &ms_context_ptr, std::mapdevice_id()); + auto env_device_id = std::to_string(ms_context_ptr->get_param(MS_CTX_DEVICE_ID)); if (!(env_table_file.empty() || env_rank_id.empty())) { MS_LOG(INFO) << "Initialize Ge for distribute parameter"; MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; @@ -275,12 +276,12 @@ bool InitGe(const std::shared_ptr &ms_context_ptr) { MS_LOG(EXCEPTION) << "nullptr"; } #ifdef ENABLE_GE - if (ms_context_ptr->is_pynative_ge_init()) { + if (ms_context_ptr->get_param(MS_CTX_IS_PYNATIVE_GE_INIT)) { return true; } - if (ms_context_ptr->ge_ref()) { - ms_context_ptr->set_ge_ref("++"); + if (ms_context_ptr->get_param(MS_CTX_GE_REF)) { + ms_context_ptr->increase_param(MS_CTX_GE_REF); return true; } @@ -293,8 +294,8 @@ bool InitGe(const std::shared_ptr &ms_context_ptr) { MS_LOG(EXCEPTION) << "Initialize GE failed!"; } } - ms_context_ptr->set_ge_ref("++"); - MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->ge_ref() << "."; + ms_context_ptr->increase_param(MS_CTX_GE_REF); + MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param(MS_CTX_GE_REF) << "."; #endif return true; } @@ 
-303,12 +304,13 @@ bool PynativeInitGe(const std::shared_ptr &ms_context_ptr) { if (ms_context_ptr == nullptr) { MS_LOG(EXCEPTION) << "nullptr"; } - if (ms_context_ptr->is_pynative_ge_init() || ms_context_ptr->ge_ref() || ms_context_ptr->tsd_ref()) { + if (ms_context_ptr->get_param(MS_CTX_IS_PYNATIVE_GE_INIT) || + ms_context_ptr->get_param(MS_CTX_GE_REF) || ms_context_ptr->get_param(MS_CTX_TSD_REF)) { return true; } (void)OpenTsd(ms_context_ptr); (void)InitGe(ms_context_ptr); - ms_context_ptr->set_pynative_ge_init(true); + ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true); return true; } @@ -317,12 +319,12 @@ bool FinalizeGe(const std::shared_ptr &ms_context_ptr, bool force) { MS_LOG(EXCEPTION) << "nullptr"; } #ifdef ENABLE_GE - if (ms_context_ptr->ge_ref() == 0) { + if (ms_context_ptr->get_param(MS_CTX_GE_REF) == 0) { return true; } - ms_context_ptr->set_ge_ref("--"); - if (force || ms_context_ptr->ge_ref() == 0) { - ms_context_ptr->set_ge_ref(" "); + ms_context_ptr->decrease_param(MS_CTX_GE_REF); + if (force || ms_context_ptr->get_param(MS_CTX_GE_REF) == 0) { + ms_context_ptr->set_param(MS_CTX_GE_REF, 0); try { DfGraphManager::GetInstance().DeleteGraphRunner(); DfGraphManager::GetInstance().DeleteGeSession(); @@ -337,7 +339,8 @@ bool FinalizeGe(const std::shared_ptr &ms_context_ptr, bool force) { } ms_context_ptr->set_pynative_ge_init(false); } else { - MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->ge_ref() << "."; + MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " + << ms_context_ptr->get_param(MS_CTX_GE_REF) << "."; } #endif return true; @@ -347,14 +350,14 @@ bool IsTsdOpened(const std::shared_ptr &ms_context_ptr) { if (ms_context_ptr == nullptr) { MS_LOG(EXCEPTION) << "nullptr"; } - return ms_context_ptr->IsTsdOpened(); + return ms_context_ptr->get_param(MS_CTX_TSD_REF) > 0; } bool IsGeInited(const std::shared_ptr &ms_context_ptr) { if (ms_context_ptr == nullptr) { MS_LOG(EXCEPTION) << 
"nullptr"; } - return ms_context_ptr->IsGeInited(); + return ms_context_ptr->get_param(MS_CTX_GE_REF) > 0; } // Register for device type. diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index 5fd7b2b8a9..d09d4485e6 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -353,7 +353,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py // When sparse enabled, the undetermined might be raised and eliminated in opt passes auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); + bool enable_sparse = context->get_param(MS_CTX_ENABLE_SPARSE); if (enable_sparse) { return std::make_shared(); } diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index bdb06d7fb3..18116909ae 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -273,7 +273,7 @@ void TensorPrint::operator()() { prntpb::Print print; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - std::string print_file_path = ms_context->print_file_path(); + std::string print_file_path = ms_context->get_param(MS_CTX_PRINT_FILE_PATH); if (print_file_path == "") { while (true) { std::vector bundle; diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 52200d1ecc..0b8bea0a24 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -59,7 +59,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri graph_id = target_sess_->CompileGraphAsync(lst, outputs); } - if (MsContext::GetInstance()->precompile_only()) { + if (MsContext::GetInstance()->get_param(MS_CTX_PRECOMPILE_ONLY)) { MS_LOG(INFO) << "PrecompileOnly, stop run graph"; return result; } @@ -180,7 +180,7 @@ void MsBackend::CreateOtherSession(const std::string &target) { 
} auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - uint32_t device_id = context_ptr->device_id(); + uint32_t device_id = context_ptr->get_param(MS_CTX_DEVICE_ID); other_sess_->Init(device_id); other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); other_device_ = target; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index fb1acfd713..048077d39d 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -56,7 +56,7 @@ namespace { bool ContainMultiTarget(const std::vector &nodes) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - std::string last_target = context_ptr->device_target(); + std::string last_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); for (auto &node : nodes) { if (node->isa()) { std::string cur_target = GetCNodeTarget(node); @@ -348,7 +348,7 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { if (prim->name() == prim::kPrimBpropCut->name()) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_enable_pynative_hook(true); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_HOOK, true); } if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { @@ -412,7 +412,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { if (ContainMultiTarget(nodes)) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); + std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); nodes = SplitSort(graph, default_target); return SplitNodesWithTarget(nodes, graph); } @@ -920,17 +920,17 @@ BackendPtr CreateBackend() { } if (name == kMsConvert) { - std::string target = context_ptr->device_target(); - uint32_t device_id = context_ptr->device_id(); + std::string target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); + uint32_t device_id = 
context_ptr->get_param(MS_CTX_DEVICE_ID); auto backend = std::make_shared(name, target, device_id); - std::string device_target = MsContext::GetInstance()->device_target(); + std::string device_target = MsContext::GetInstance()->get_param(MS_CTX_DEVICE_TARGET); if (device_target == kAscendDevice) { - if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { backend->set_is_multi_graph_sink(false); - context_ptr->set_is_multi_graph_sink(false); + context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, false); } else { backend->set_is_multi_graph_sink(true); - context_ptr->set_is_multi_graph_sink(true); + context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, true); } } return backend; diff --git a/mindspore/context.py b/mindspore/context.py index 361c48c09b..7788a4f367 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -22,7 +22,7 @@ import threading from collections import namedtuple from types import FunctionType from mindspore import log as logger -from mindspore._c_expression import MSContext +from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param from mindspore._checkparam import args_type_check from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ _reset_auto_parallel_context @@ -157,9 +157,15 @@ class _Context: raise ValueError("Context handle is none in context!!!") return value + def get_param(self, param): + return ms_ctx_get_param(self._context_handle, param) + + def set_param(self, param, value): + ms_ctx_set_param(self._context_handle, param, value) + @property def mode(self): - return self._context_handle.get_execution_mode() + return self.get_param(ms_ctx_param.execution_mode) @mode.setter def mode(self, mode): @@ -169,15 +175,17 @@ class _Context: Args: mode (int): GRAPH_MODE or PYNATIVE_MODE. 
""" - self._context_handle.set_execution_mode(mode) if mode == PYNATIVE_MODE: if self.enable_debug_runtime: self.set_backend_policy("vm") self._context_switches.push(True, None) - else: + elif mode == GRAPH_MODE: if self.enable_debug_runtime: self.set_backend_policy("ge") self._context_switches.push(False, None) + else: + raise ValueError(f'The execution mode {mode} is invalid!') + self.set_param(ms_ctx_param.execution_mode, mode) def set_backend_policy(self, policy): success = self._context_handle.set_backend_policy(policy) @@ -186,110 +194,106 @@ class _Context: @property def precompile_only(self): - return self._context_handle.get_precompile_only() + return self.get_param(ms_ctx_param.precompile_only) @precompile_only.setter def precompile_only(self, precompile_only): - self._context_handle.set_precompile_only(precompile_only) + self.set_param(ms_ctx_param.precompile_only, precompile_only) @property def save_graphs(self): - return self._context_handle.get_save_graphs_flag() + return self.get_param(ms_ctx_param.save_graphs_flag) @save_graphs.setter def save_graphs(self, save_graphs_flag): - self._context_handle.set_save_graphs_flag(save_graphs_flag) + self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag) @property def save_graphs_path(self): - return self._context_handle.get_save_graphs_path() + return self.get_param(ms_ctx_param.save_graphs_path) @save_graphs_path.setter def save_graphs_path(self, save_graphs_path): - self._context_handle.set_save_graphs_path( - _make_directory(save_graphs_path)) + self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) @property def device_target(self): - return self._context_handle.get_device_target() + return self.get_param(ms_ctx_param.device_target) @device_target.setter def device_target(self, target): - success = self._context_handle.set_device_target(target) - if not success: - raise ValueError("Target device name is invalid!!!") - if self.enable_debug_runtime and self.device_target == 
"CPU": + valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] + if not target in valid_targets: + raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}") + if target == "Davinci": + target = "Ascend" + self.set_param(ms_ctx_param.device_target, target) + if self.enable_debug_runtime and target == "CPU": self.set_backend_policy("vm") @property def device_id(self): - return self._context_handle.get_device_id() + return self.get_param(ms_ctx_param.device_id) @device_id.setter def device_id(self, device_id): if device_id < 0 or device_id > 4095: - raise ValueError( - "Device id must be in [0, 4095], but got {}".format(device_id)) - success = self._context_handle.set_device_id(device_id) - if not success: - raise RuntimeError("Device id set failed!!!") + raise ValueError(f"Device id must be in [0, 4095], but got {device_id}") + self.set_param(ms_ctx_param.device_id, device_id) @property def max_call_depth(self): - return self._context_handle.get_max_call_depth() + return self.get_param(ms_ctx_param.max_call_depth) @max_call_depth.setter def max_call_depth(self, max_call_depth): if max_call_depth <= 0: - raise ValueError( - "Max call depth must be greater than 0, but got {}".format(max_call_depth)) - self._context_handle.set_max_call_depth(max_call_depth) + raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}") + self.set_param(ms_ctx_param.max_call_depth, max_call_depth) @property def enable_auto_mixed_precision(self): - return self._context_handle.get_auto_mixed_precision_flag() + return self.get_param(ms_ctx_param.auto_mixed_precision_flag) @enable_auto_mixed_precision.setter def enable_auto_mixed_precision(self, enable_auto_mixed_precision): - self._context_handle.set_auto_mixed_precision_flag( - enable_auto_mixed_precision) + self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision) @property def enable_reduce_precision(self): - return 
self._context_handle.get_enable_reduce_precision_flag() + return self.get_param(ms_ctx_param.enable_reduce_precision_flag) @enable_reduce_precision.setter def enable_reduce_precision(self, enable_reduce_precision): - self._context_handle.set_enable_reduce_precision_flag( - enable_reduce_precision) + self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision) @property def enable_dump(self): - return self._context_handle.get_enable_dump() + return self.get_param(ms_ctx_param.enable_dump) @enable_dump.setter def enable_dump(self, enable_dump): - self._context_handle.set_enable_dump(enable_dump) + self.set_param(ms_ctx_param.enable_dump, enable_dump) @property def save_dump_path(self): - return self._context_handle.get_save_dump_path() + return self.get_param(ms_ctx_param.save_dump_path) @save_dump_path.setter def save_dump_path(self, save_dump_path): - self._context_handle.set_save_dump_path(save_dump_path) + self.set_param(ms_ctx_param.save_dump_path, save_dump_path) @property def enable_profiling(self): - return self._context_handle.get_enable_profiling() + return self.get_param(ms_ctx_param.enable_profiling) @enable_profiling.setter def enable_profiling(self, flag): - self._context_handle.set_enable_profiling(flag) + self.set_param(ms_ctx_param.enable_profiling, flag) @property def profiling_options(self): - return self._context_handle.get_profiling_options() + return self.get_param(ms_ctx_param.profiling_options) @profiling_options.setter def profiling_options(self, option): @@ -298,15 +302,15 @@ class _Context: if option not in options: raise ValueError("Profiling options must be in 'training_trace' 'task_trace' " "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") - self._context_handle.set_profiling_options(option) + self.set_param(ms_ctx_param.profiling_options, option) @property def enable_graph_kernel(self): - return self._context_handle.get_enable_graph_kernel() + return 
self.get_param(ms_ctx_param.enable_graph_kernel) @enable_graph_kernel.setter def enable_graph_kernel(self, graph_kernel_switch_): - self._context_handle.set_enable_graph_kernel(graph_kernel_switch_) + self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_) @property def reserve_class_name_in_scope(self): @@ -325,20 +329,14 @@ class _Context: @variable_memory_max_size.setter def variable_memory_max_size(self, variable_memory_max_size): if not check_input_format(variable_memory_max_size): - raise ValueError( - "Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") + raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: - raise ValueError( - "Context param variable_memory_max_size should be less than 31GB.") - variable_memory_max_size_ = variable_memory_max_size[:- - 2] + " * 1024 * 1024 * 1024" - graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - \ - int(variable_memory_max_size[:-2]) - graph_memory_max_size_ = str( - graph_memory_max_size) + " * 1024 * 1024 * 1024" - self._context_handle.set_variable_memory_max_size( - variable_memory_max_size_) - self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) + raise ValueError("Context param variable_memory_max_size should be less than 31GB.") + variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" + graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) + graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" + self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) + self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) @property def enable_ge(self): @@ -355,15 +353,15 @@ class _Context: @property def check_bprop(self): - return self._context_handle.get_check_bprop_flag() + return 
self.get_param(ms_ctx_param.check_bprop_flag) @check_bprop.setter def check_bprop(self, check_bprop_flag): - self._context_handle.set_check_bprop_flag(check_bprop_flag) + self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag) @property def max_device_memory(self): - return self._context_handle.get_max_device_memory() + return self.get_param(ms_ctx_param.max_device_memory) @max_device_memory.setter def max_device_memory(self, max_device_memory): @@ -372,7 +370,7 @@ class _Context: max_device_memory_value = float(max_device_memory[:-2]) if max_device_memory_value == 0: raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") - self._context_handle.set_max_device_memory(max_device_memory_value) + self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) @property def print_file_path(self): @@ -392,15 +390,15 @@ class _Context: full_file_name = os.path.join(path, file_name) else: full_file_name = print_file_path - self._context_handle.set_print_file_path(full_file_name) + self.set_param(ms_ctx_param.print_file_path, full_file_name) @property def enable_sparse(self): - return self._context_handle.get_enable_sparse() + return self.get_param(ms_ctx_param.enable_sparse) @enable_sparse.setter def enable_sparse(self, enable_sparse): - self._context_handle.set_enable_sparse(enable_sparse) + self.set_param(ms_ctx_param.enable_sparse, enable_sparse) def check_input_format(x): import re @@ -486,8 +484,6 @@ def set_auto_parallel_context(**kwargs): full_batch (bool): Whether to load the whole batch on each device. Default: False. enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in data parallel training in the benefit of time and memory saving. - max_call_depth(int): Specify the function call depth limit. Default: 1000. - Raises: ValueError: If input key is not attribute in auto parallel context. 
@@ -501,7 +497,6 @@ def set_auto_parallel_context(**kwargs): >>> context.set_auto_parallel_context(parameter_broadcast=False) >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") - >>> context.set_auto_parallel_context(max_call_depth=80) """ _set_auto_parallel_context(**kwargs) @@ -603,6 +598,7 @@ def set_context(**kwargs): a file by default, and turn off printing to the screen. If the file already exists, add a timestamp suffix to the file. enable_sparse (bool): Whether to enable sparsity feature. Default: False. + max_call_depth(int): Specify the function call depth limit. Default: 1000. Raises: ValueError: If input key is not an attribute in context. @@ -623,6 +619,7 @@ def set_context(**kwargs): >>> context.set_context(enable_profiling=True, profiling_options="training_trace") >>> context.set_context(max_device_memory="3.5GB") >>> context.set_context(print_file_path="print.pb") + >>> context.set_context(max_call_depth=80) """ for key, value in kwargs.items(): if not hasattr(_context(), key): diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 448d4dad90..868c50aa7c 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -51,7 +51,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); + bool enable_sparse = context->get_param(MS_CTX_ENABLE_SPARSE); if (enable_sparse && dflt->isa()) { auto dflt_tensor = dflt->cast(); return std::make_shared(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 9394a2792d..0446ab988d 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -232,7 +232,7 @@ std::string 
GetMaketupleNodeTarget(const CNodePtr &cnode) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); + std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); return default_target; } @@ -248,7 +248,7 @@ std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &pri std::string GetCNodeTarget(const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); + std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); if (!node->isa()) { return default_target; } diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index f47dba7573..71673b362d 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -652,7 +652,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP new_func_graph->set_param_default_value(item.first, cloner[item.second]); } - if (MsContext::GetInstance()->is_multi_graph_sink()) { + if (MsContext::GetInstance()->get_param(MS_CTX_IS_MULTI_GRAPH_SINK)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); } diff --git a/mindspore/core/ir/meta_func_graph.cc b/mindspore/core/ir/meta_func_graph.cc index 9aefcac52c..d89ef3ab9a 100644 --- a/mindspore/core/ir/meta_func_graph.cc +++ b/mindspore/core/ir/meta_func_graph.cc @@ -30,7 +30,7 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() { FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); + bool enable_sparse = context->get_param(MS_CTX_ENABLE_SPARSE); if (!enable_sparse) { return nullptr; } diff --git a/mindspore/core/utils/ms_context.cc 
b/mindspore/core/utils/ms_context.cc index 42453d9904..064a76b065 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -32,49 +32,50 @@ std::map MsContext::policy_map_ = {{"ge", kMsBacke {"vm_prior", kMsBackendVmPrior}}; MsContext::MsContext(const std::string &policy, const std::string &target) { - save_graphs_flag_ = false; - save_graphs_path_ = "."; - enable_dump_ = false; - save_dump_path_ = "."; - tsd_ref_ = 0; - ge_ref_ = 0; - is_multi_graph_sink_ = false; - is_pynative_ge_init_ = false; - enable_reduce_precision_ = true; + set_param(MS_CTX_SAVE_GRAPHS_FLAG, false); + set_param(MS_CTX_SAVE_GRAPHS_PATH, "."); + set_param(MS_CTX_SAVE_DUMP_PATH, "."); + set_param(MS_CTX_TSD_REF, 0); + set_param(MS_CTX_GE_REF, 0); + set_param(MS_CTX_IS_MULTI_GRAPH_SINK, false); + set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); + set_param(MS_CTX_ENABLE_REDUCE_PRECISION, true); auto env_device = common::GetEnv("DEVICE_ID"); if (!env_device.empty()) { - device_id_ = UlongToUint(std::stoul(env_device.c_str())); + uint32_t device_id = UlongToUint(std::stoul(env_device.c_str())); + set_param(MS_CTX_DEVICE_ID, device_id); } else { - device_id_ = 0; + set_param(MS_CTX_DEVICE_ID, 0); } - max_call_depth_ = MAX_CALL_DEPTH_DEFAULT; - backend_policy_ = policy_map_[policy]; - device_target_ = target; - execution_mode_ = kPynativeMode; - enable_task_sink_ = true; - ir_fusion_flag_ = true; - enable_hccl_ = false; + set_param(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT); + set_param(MS_CTX_DEVICE_TARGET, target); + set_param(MS_CTX_EXECUTION_MODE, kPynativeMode); + set_param(MS_CTX_ENABLE_TASK_SINK, true); + set_param(MS_CTX_IR_FUSION_FLAG, true); + set_param(MS_CTX_ENABLE_HCCL, false); #ifdef ENABLE_DEBUGGER - enable_mem_reuse_ = false; + set_param(MS_CTX_ENABLE_MEM_REUSE, false); #else - enable_mem_reuse_ = true; + set_param(MS_CTX_ENABLE_MEM_REUSE, true); #endif - enable_gpu_summary_ = true; - precompile_only_ = false; - auto_mixed_precision_flag_ = 
false; - enable_pynative_infer_ = false; - enable_pynative_hook_ = false; - enable_dynamic_mem_pool_ = true; - graph_memory_max_size_ = "0"; - variable_memory_max_size_ = "0"; - enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice; - profiling_mode_ = false; - profiling_options_ = "training_trace"; - check_bprop_flag_ = false; - max_device_memory_ = kDefaultMaxDeviceMemory; - print_file_path_ = ""; - enable_graph_kernel_ = false; - enable_sparse_ = false; + set_param(MS_CTX_ENABLE_GPU_SUMMARY, true); + set_param(MS_CTX_PRECOMPILE_ONLY, false); + set_param(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false); + set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); + set_param(MS_CTX_ENABLE_PYNATIVE_HOOK, false); + set_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true); + set_param(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0"); + set_param(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0"); + set_param(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice); + set_param(MS_CTX_ENABLE_PROFILING, false); + set_param(MS_CTX_PROFILING_OPTIONS, "training_trace"); + set_param(MS_CTX_CHECK_BPROP_FLAG, false); + set_param(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory); + set_param(MS_CTX_PRINT_FILE_PATH, ""); + set_param(MS_CTX_ENABLE_GRAPH_KERNEL, false); + set_param(MS_CTX_ENABLE_SPARSE, false); + + backend_policy_ = policy_map_[policy]; } std::shared_ptr MsContext::GetInstance() { @@ -106,54 +107,4 @@ std::string MsContext::backend_policy() const { } return "unknown"; } - -void MsContext::set_execution_mode(int execution_mode) { - if (execution_mode != kGraphMode && execution_mode != kPynativeMode) { - MS_LOG(EXCEPTION) << "The execution mode is invalid!"; - } - execution_mode_ = execution_mode; -} - -bool MsContext::set_device_target(const std::string &target) { - if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(ERROR) << "invalid device target name: " << target; - return false; - } - if (target == kDavinciDevice) { - device_target_ = kAscendDevice; - } else { - 
device_target_ = target; - } - if (seter_) { - seter_(device_target_); - } - MS_LOG(INFO) << "ms set context device target:" << target; - return true; -} - -bool MsContext::set_device_id(uint32_t device_id) { - device_id_ = device_id; - MS_LOG(INFO) << "ms set context device id:" << device_id; - return true; -} - -void MsContext::set_tsd_ref(const std::string &op) { - if (op == "--") { - tsd_ref_--; - } else if (op == "++") { - tsd_ref_++; - } else { - tsd_ref_ = 0; - } -} - -void MsContext::set_ge_ref(const std::string &op) { - if (op == "--") { - ge_ref_--; - } else if (op == "++") { - ge_ref_++; - } else { - ge_ref_ = 0; - } -} } // namespace mindspore diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index 9d29ad5bdc..9d461eaa43 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -49,6 +49,69 @@ const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, // The default max available device memory is 1024GB. 
const float kDefaultMaxDeviceMemory = 1024; +// enum definition for MindSpore Context Parameter +enum MsCtxParam : unsigned { + // parameter of type bool + MS_CTX_TYPE_BOOL_BEGIN, + MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN, + MS_CTX_CHECK_BPROP_FLAG, + MS_CTX_ENABLE_DUMP, + MS_CTX_ENABLE_DYNAMIC_MEM_POOL, + MS_CTX_ENABLE_GPU_SUMMARY, + MS_CTX_ENABLE_GRAPH_KERNEL, + MS_CTX_ENABLE_HCCL, + MS_CTX_ENABLE_LOOP_SINK, + MS_CTX_ENABLE_MEM_REUSE, + MS_CTX_ENABLE_PYNATIVE_HOOK, + MS_CTX_ENABLE_PYNATIVE_INFER, + MS_CTX_ENABLE_REDUCE_PRECISION, + MS_CTX_ENABLE_SPARSE, + MS_CTX_ENABLE_TASK_SINK, + MS_CTX_IR_FUSION_FLAG, + MS_CTX_IS_MULTI_GRAPH_SINK, + MS_CTX_IS_PYNATIVE_GE_INIT, + MS_CTX_PRECOMPILE_ONLY, + MS_CTX_ENABLE_PROFILING, + MS_CTX_SAVE_GRAPHS_FLAG, + MS_CTX_TYPE_BOOL_END, + + // parameter of type int + MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, + MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, + MS_CTX_TYPE_INT_END, + + // parameter of type uint32 + MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, + MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, + MS_CTX_GE_REF, + MS_CTX_MAX_CALL_DEPTH, + MS_CTX_TSD_REF, + MS_CTX_TYPE_UINT32_END, + + // parameter of type float + MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, + MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, + MS_CTX_TYPE_FLOAT_END, + + // parameter of type string + MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, + MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, + MS_CTX_GRAPH_MEMORY_MAX_SIZE, + MS_CTX_PRINT_FILE_PATH, + MS_CTX_PROFILING_OPTIONS, + MS_CTX_SAVE_DUMP_PATH, + MS_CTX_SAVE_GRAPHS_PATH, + MS_CTX_VARIABLE_MEMORY_MAX_SIZE, + MS_CTX_TYPE_STRING_END, + + // parameter numbers of each type + NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN, + NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN, + NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN, + NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN, + NUM_STRING_PARAMS = 
MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN +}; + class MsContext { public: MsContext(const std::string &backend_policy, const std::string &target); @@ -62,156 +125,113 @@ class MsContext { std::string backend_policy() const; bool set_backend_policy(const std::string &policy); - int execution_mode() const { return execution_mode_; } - void set_execution_mode(int execution_mode); - - bool enable_pynative_infer() const { return enable_pynative_infer_; } - void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; } - - bool enable_pynative_hook() const { return enable_pynative_hook_; } - void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; } - - bool enable_task_sink() const { return enable_task_sink_; } - - void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; } - bool precompile_only() const { return precompile_only_; } - - std::string device_target() const { return device_target_; } - bool set_device_target(const std::string &target); - - uint32_t device_id() const { return device_id_; } - bool set_device_id(uint32_t device_id); - - // uint32_t max_call_depth_ - uint32_t max_call_depth() const { return max_call_depth_; } - inline bool set_max_call_depth(uint32_t max_call_depth) { - max_call_depth_ = max_call_depth; - return true; - } - - bool save_graphs_flag() const { return save_graphs_flag_; } - void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } - - std::string save_graphs_path() const { return save_graphs_path_; } - void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } - - bool IsGeInited() { return ge_ref_ > 0; } - void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; } - bool enable_hccl() const { return enable_hccl_; } - bool ir_fusion_flag() const { return ir_fusion_flag_; } - bool loop_sink_flag() const { return 
enable_loop_sink_; } - void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; } - void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } - bool enable_mem_reuse() const { return enable_mem_reuse_; } - - void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } - bool enable_gpu_summary() const { return enable_gpu_summary_; } - - void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) { - auto_mixed_precision_flag_ = auto_mixed_precision_flag; - } - bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; } - - void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; } - bool enable_reduce_precision() const { return enable_reduce_precision_; } - - void set_enable_dump(bool flag) { enable_dump_ = flag; } - bool enable_dump() const { return enable_dump_; } - - void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } - std::string save_dump_path() const { return save_dump_path_; } - - bool IsTsdOpened() const { return tsd_ref_ > 0; } - void set_tsd_ref(const std::string &op); - uint32_t tsd_ref() const { return tsd_ref_; } - - void set_ge_ref(const std::string &op); - uint32_t ge_ref() const { return ge_ref_; } - - bool is_pynative_ge_init() { return is_pynative_ge_init_; } - void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; } - - bool is_multi_graph_sink() const { return is_multi_graph_sink_; } - void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } - - void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } - bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } - - void set_graph_memory_max_size(const std::string &graph_memory_max_size) { - graph_memory_max_size_ = graph_memory_max_size; - } - - void set_variable_memory_max_size(const std::string 
&variable_memory_max_size) { - variable_memory_max_size_ = variable_memory_max_size; - } - - const std::string &variable_memory_max_size() const { return variable_memory_max_size_; } - - const std::string &graph_memory_max_size() const { return graph_memory_max_size_; } - - void set_enable_profiling(bool flag) { profiling_mode_ = flag; } - bool enable_profiling() const { return profiling_mode_; } - - void set_profiling_options(const std::string &options) { profiling_options_ = options; } - std::string profiling_options() const { return profiling_options_; } - bool check_bprop_flag() const { return check_bprop_flag_; } - void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; } - void set_print_file_path(const std::string &file) { print_file_path_ = file; } - const std::string &print_file_path() const { return print_file_path_; } - - float max_device_memory() const { return max_device_memory_; } - void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; } - - void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } - bool enable_graph_kernel() const { return enable_graph_kernel_; } - - bool enable_sparse() const { return enable_sparse_; } - void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } static void device_seter(DeviceSeter device) { seter_ = device; } static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } std::thread tdt_print_; + template + void set_param(MsCtxParam param, const T &value) { + MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + } + + template + const T &get_param(MsCtxParam param) const { + MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + } + + template + void increase_param(MsCtxParam param) { + MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " 
<< typeid(T).name() << "."; + } + + template + void decrease_param(MsCtxParam param) { + MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + } + private: inline static DeviceSeter seter_ = nullptr; inline static DeviceTypeSeter device_type_seter_ = nullptr; static std::shared_ptr inst_context_; static std::map policy_map_; + + bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS]; + int int_params_[MsCtxParam::NUM_INT_PARAMS]; + uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; + float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; + std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; + MsBackendPolicy backend_policy_; - std::string device_target_; - uint32_t device_id_; - uint32_t max_call_depth_; - int execution_mode_; - bool enable_pynative_infer_; - bool enable_pynative_hook_; - bool save_graphs_flag_; - std::string save_graphs_path_; - uint32_t tsd_ref_; - uint32_t ge_ref_; - bool enable_task_sink_; - bool enable_hccl_; - bool precompile_only_; - bool ir_fusion_flag_; - bool auto_mixed_precision_flag_; - bool enable_reduce_precision_; - bool enable_loop_sink_; - bool enable_mem_reuse_; - bool enable_gpu_summary_; - bool enable_dump_; - std::string save_dump_path_; - bool is_multi_graph_sink_; - bool is_pynative_ge_init_; - bool enable_dynamic_mem_pool_; - std::string graph_memory_max_size_; - std::string variable_memory_max_size_; - bool profiling_mode_; - std::string profiling_options_; - bool check_bprop_flag_; - float max_device_memory_; - std::string print_file_path_; - bool enable_graph_kernel_; - bool enable_sparse_; }; + +// set method implementation for type bool/int/uint32_t/float/std::string +template <> +inline void MsContext::set_param(MsCtxParam param, const bool &value) { + bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value; +} + +template <> +inline void MsContext::set_param(MsCtxParam param, const int &value) { + int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value; +} + +template <> +inline 
void MsContext::set_param(MsCtxParam param, const uint32_t &value) { + uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value; +} + +template <> +inline void MsContext::set_param(MsCtxParam param, const float &value) { + float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value; +} + +template <> +inline void MsContext::set_param(MsCtxParam param, const std::string &value) { + if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) { + MS_LOG(INFO) << "ms set context device target:" << value; + seter_(value); + } + string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value; +} + +// get method implementation for type bool/int/uint32_t/float/std::string +template <> +inline const bool &MsContext::get_param(MsCtxParam param) const { + return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN]; +} + +template <> +inline const int &MsContext::get_param(MsCtxParam param) const { + return int_params_[param - MS_CTX_TYPE_INT_BEGIN]; +} + +template <> +inline const uint32_t &MsContext::get_param(MsCtxParam param) const { + return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]; +} + +template <> +inline const float &MsContext::get_param(MsCtxParam param) const { + return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN]; +} + +template <> +inline const std::string &MsContext::get_param(MsCtxParam param) const { + return string_params_[param - MS_CTX_TYPE_STRING_BEGIN]; +} + +// increate method implementation for type uint32_t +template <> +inline void MsContext::increase_param(MsCtxParam param) { + uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++; +} + +// decreate method implementation for type uint32_t +template <> +inline void MsContext::decrease_param(MsCtxParam param) { + uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--; +} } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index faf03505c7..5dafd371b4 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ 
b/tests/ut/cpp/optimizer/lib_test.cc @@ -42,7 +42,7 @@ class TestOptLib : public UT::Common { parse::data_converter::ClearObjectCache(); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_execution_mode(kGraphMode); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); } FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { equiv_node.clear(); diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc index fbb3733c9b..3be96b9fed 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc @@ -112,7 +112,7 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) { */ auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_execution_mode(kGraphMode); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0"); // Do insert_trans_op_ pass of hardware opt auto graph_optimizer = std::make_shared(); diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc index a084ff477b..09bed0a913 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc @@ -112,7 +112,7 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect { TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_execution_mode(kGraphMode); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); auto kg = 
GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before"); // insert trans op for output auto graph_optimizer = std::make_shared(); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc index f1a1dbe026..a199695fdc 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc @@ -104,7 +104,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) { */ auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_execution_mode(kGraphMode); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before"); std::vector shp{2, 4, 8, 16}; auto x_abstract = std::make_shared(kFloat32, shp); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 5f16016d2c..614817d1ef 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -83,7 +83,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { */ auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_execution_mode(kGraphMode); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before"); std::vector shp{2, 4, 8, 16}; auto x_abstract = std::make_shared(kFloat32, shp); diff --git a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc index 60d99d2dad..f22019845f 100644 --- 
a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc @@ -76,7 +76,7 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) { */ auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_execution_mode(kGraphMode); + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before"); // Renormalize func_graph to infer and set shape and type information. std::vector shp{2, 32, 224, 224}; diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc index c3c848423b..e14935ab98 100644 --- a/tests/ut/cpp/pynative/pynative_execute_test.cc +++ b/tests/ut/cpp/pynative/pynative_execute_test.cc @@ -71,15 +71,15 @@ OpExecInfoPtr ConstructOpExecInfo() { TEST_F(TestPynativeExecute, TestCreateContext) { auto ctx3 = MsContext::GetInstance(); ASSERT_EQ(ctx3->backend_policy(), "vm"); - ASSERT_EQ(ctx3->device_target(), "CPU"); + ASSERT_EQ(ctx3->get_param(MS_CTX_DEVICE_TARGET), "CPU"); ctx3->set_backend_policy("ge_only"); - ctx3->set_device_target("GPU"); + ctx3->set_param(MS_CTX_DEVICE_TARGET, "GPU"); auto ctx4 = MsContext::GetInstance(); ASSERT_EQ(ctx3.get(), ctx4.get()); ASSERT_EQ(ctx4->backend_policy(), "ge_only"); - ASSERT_EQ(ctx4->device_target(), "GPU"); + ASSERT_EQ(ctx4->get_param(MS_CTX_DEVICE_TARGET), "GPU"); } TEST_F(TestPynativeExecute, TestDefaultContext) {