!5352 refactor ms_context implementation

Merge pull request !5352 from fary86/refactor_context_interface
This commit is contained in:
mindspore-ci-bot 2020-08-31 09:27:33 +08:00 committed by Gitee
commit 8d41931456
77 changed files with 582 additions and 554 deletions

View File

@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) { const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true; return true;
} }
if (inputs.empty() || hccl_data_type_list_.empty()) { if (inputs.empty() || hccl_data_type_list_.empty()) {

View File

@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const std::vector<AddressPtr> &outputs, void *stream_ptr) { const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true; return true;
} }
if (inputs.empty() || hccl_data_type_list_.empty()) { if (inputs.empty() || hccl_data_type_list_.empty()) {

View File

@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const std::vector<AddressPtr> &outputs, void *stream_ptr) { const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true; return true;
} }
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {

View File

@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs, void *stream_ptr) { const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_task_sink()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return true; return true;
} }
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {

View File

@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL) &&
IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
kernel_type = KernelType::AKG_KERNEL; kernel_type = KernelType::AKG_KERNEL;
} }

View File

@ -328,7 +328,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
} }
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice); bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
<< ", current op num: " << op_info_.size(); << ", current op num: " << op_info_.size();

View File

@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
} }
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
if (context_ptr->execution_mode() == kPynativeMode) { if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else { } else {
@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
AddAscendIRFusionRulesPass(ir_fusion_pm.get()); AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get()); AddAscendIRFusionPass(ir_fusion_pm.get());
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
ConfigManager::GetInstance().iter_num() > 1) {
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>()); ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->ir_fusion_flag()) { if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
MS_LOG(INFO) << "IRFusion is not enable, skip"; MS_LOG(INFO) << "IRFusion is not enable, skip";
return; return;
} }
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto other2_pm = std::make_shared<PassManager>("other2_pm"); auto other2_pm = std::make_shared<PassManager>("other2_pm");
other2_pm->AddPass(std::make_shared<GetitemTuple>()); other2_pm->AddPass(std::make_shared<GetitemTuple>());
other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
ConfigManager::GetInstance().iter_num() > 1) {
other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>()); other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
} }
other2_pm->AddPass(std::make_shared<CheckConsistency>()); other2_pm->AddPass(std::make_shared<CheckConsistency>());
@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
bool is_before_kernel_select) { bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) { if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return; return;
} }
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
bool is_before_kernel_select) { bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) { if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return; return;
} }
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) { if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return; return;
} }
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &ke
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->ir_fusion_flag()) { if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
MS_LOG(INFO) << "UBFusion is not enable, skip"; MS_LOG(INFO) << "UBFusion is not enable, skip";
return; return;
} }
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }

View File

@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
} }
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
return new_node; return new_node;
} }

View File

@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return RectifyKernelInfoInPynativeProcess(node); return RectifyKernelInfoInPynativeProcess(node);
} }
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {

View File

@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }

View File

@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) { bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return false; return false;
} }
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,

View File

@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
} }
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }

View File

@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->device_target() == kAscendDevice) { if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
if (!CheckAttrs(strided_slice_grad)) { if (!CheckAttrs(strided_slice_grad)) {
MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed";
return nullptr; return nullptr;

View File

@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
if (context_ptr->save_graphs_flag()) { if (context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir"; std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir";
DumpIR(file_path, root_graph.get()); DumpIR(file_path, root_graph.get());
} }

View File

@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
debugger_->PreExecute(graph); debugger_->PreExecute(graph);
} }
#endif #endif
if (ms_context->precompile_only()) { if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "Precompile only, stop in build kernel step"; MS_LOG(INFO) << "Precompile only, stop in build kernel step";
} else { } else {
// alloc memory, including static memory and dynamic memory // alloc memory, including static memory and dynamic memory
@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
child_graph->SetExecOrderByDefault(); child_graph->SetExecOrderByDefault();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
} }
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kGraphMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (raise_precision_count > 0) { if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There has " << raise_precision_count MS_LOG(WARNING) << "There has " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!"; << " node/nodes used raise precision to selected the kernel!";
@ -481,8 +481,8 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs)
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (!save_graphs) { if (!save_graphs) {
return; return;
} }
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() {
if (graph_order.size() > 1) { if (graph_order.size() > 1) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->enable_task_sink()) { if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!"; MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
} }
} }
@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs) { if (save_graphs) {
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kGraphMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (raise_precision_count > 0) { if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There are " << raise_precision_count MS_LOG(WARNING) << "There are " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!"; << " node/nodes used raise precision to selected the kernel!";
@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs) { if (save_graphs) {
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";

View File

@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) { if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>()); pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>()); pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>()); pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
bool need_sync = false; bool need_sync = false;
if (ms_context->enable_pynative_infer()) { if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
if (tensor_address == nullptr || tensor_address != device_address) { if (tensor_address == nullptr || tensor_address != device_address) {
need_sync = true; need_sync = true;
} }
@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
// Prepare ms context info for dump .pb graph // Prepare ms context info for dump .pb graph
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Optimize // Optimize
Optimize(graph); Optimize(graph);
// Select kernel build info // Select kernel build info
@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
// Summary // Summary
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_gpu_summary()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) {
Summary(kernel_graph.get()); Summary(kernel_graph.get());
} }
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER

View File

@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() {
return; return;
} }
Initialized = true; Initialized = true;
MsContext::GetInstance()->set_execution_mode(kGraphMode); MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
Py_Initialize(); Py_Initialize();
auto c_expression = PyImport_ImportModule("mindspore._c_expression"); auto c_expression = PyImport_ImportModule("mindspore._c_expression");
if (c_expression == nullptr) { if (c_expression == nullptr) {
@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
MS_LOG(ERROR) << "Get Context failed!"; MS_LOG(ERROR) << "Get Context failed!";
return FAILED; return FAILED;
} }
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_device_id(device_id); ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
auto ajust_device = AjustTargetName(device); auto ajust_device = AjustTargetName(device);
if (ajust_device == "") { if (ajust_device == "") {
return FAILED; return FAILED;
} }
ms_context->set_device_target(device); ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device);
if (!context::OpenTsd(ms_context)) { if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!"; MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED; return FAILED;

View File

@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
// if in paynative mode,data only copyed to host when user want to print data // if in paynative mode,data only copyed to host when user want to print data
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
tensor->set_need_sync(true); tensor->set_need_sync(true);
} }
if (ms_context->execution_mode() != kPynativeMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
tensor->SetNeedWait(true); tensor->SetNeedWait(true);
} }
tensor->set_dirty(false); tensor->set_dirty(false);
@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
if (ms_context->enable_pynative_infer()) { if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
} }
if (tensor->is_dirty()) { if (tensor->is_dirty()) {
@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL(input_node); MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->execution_mode() == kPynativeMode || if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address); tensor->set_device_address(device_address);
} }
@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if (backend_anf != nullptr) { if (backend_anf != nullptr) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->execution_mode() == kPynativeMode) { if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return backend_anf; return backend_anf;
} }

View File

@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
debugger_ = Debugger::GetInstance(); debugger_ = Debugger::GetInstance();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
debugger_->Init(device_id_, ms_context->device_target()); debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
} }
#endif #endif

View File

@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
if (context->execution_mode() == kPynativeMode) { if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump"; MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump";
} }
return true; return true;

View File

@ -142,7 +142,7 @@ void Debugger::EnableDebugger() {
// switch memory reuse on or off // switch memory reuse on or off
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
context_ptr->set_enable_mem_reuse(partial_memory_); context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
// print some message about memory reuse to user // print some message about memory reuse to user
if (partial_memory_) { if (partial_memory_) {
MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first " MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "

View File

@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
MS_LOG(ERROR) << "ms_context is nullptr"; MS_LOG(ERROR) << "ms_context is nullptr";
return; return;
} }
auto save_graphs_path = ms_context->save_graphs_path(); auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }

View File

@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
// dump_enable_ is true, close mem reuse // dump_enable_ is true, close mem reuse
context_ptr->set_enable_mem_reuse(!dump_enable_); context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, !dump_enable_);
trans_flag_ = trans_flag; trans_flag_ = trans_flag;
dump_mode_ = mode; dump_mode_ = mode;
dump_path_ = path; dump_path_ = path;
@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() {
} }
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto id = context_ptr->device_id(); auto id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
if (nullptr == realpath(config_path_str, real_path)) { if (nullptr == realpath(config_path_str, real_path)) {
MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str; MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str;

View File

@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
manager_ptr->AddFuncGraph(func_graph); manager_ptr->AddFuncGraph(func_graph);
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
if (MsContext::GetInstance()->is_multi_graph_sink()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }

View File

@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool check_bprop_flag = context->check_bprop_flag(); bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
// Skip checking if check_bprop not set // Skip checking if check_bprop not set
if (!check_bprop_flag) { if (!check_bprop_flag) {
return; return;

View File

@ -29,7 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
PConstant const_2(node); PConstant const_2(node);
PConstant any_const(node); PConstant any_const(node);
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
MATCH_REPLACE(node, x + zero_, x); // Add by zero MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
@ -41,7 +41,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
} }
// Prim Eliminate (identity) // Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return nullptr; return nullptr;
} }
@ -75,7 +75,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
} }
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return nullptr; return nullptr;
} }
PatternNode x, y; PatternNode x, y;

View File

@ -181,7 +181,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
} }
}; };
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { if (is_on_debug_ && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
auto fg_name = auto fg_name =
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];

View File

@ -217,8 +217,8 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
void DrawNode(string name, AnfNodePtr node) { void DrawNode(string name, AnfNodePtr node) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
bool save_graphs = context_ptr->save_graphs_flag(); bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
auto save_graphs_path = context_ptr->save_graphs_path(); auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }

View File

@ -44,7 +44,7 @@ std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::str
} }
void DumpGraph(const FuncGraphPtr &root, const std::string &name) { void DumpGraph(const FuncGraphPtr &root, const std::string &name) {
if (MsContext::GetInstance()->save_graphs_flag()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw(name + ".dot", root); draw::Draw(name + ".dot", root);
DumpIR(name + ".ir", root); DumpIR(name + ".ir", root);
ExportIR(name + ".dat", "0", root); ExportIR(name + ".dat", "0", root);

View File

@ -69,7 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
struct timeval start_time, end_time; struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr); (void)gettimeofday(&start_time, nullptr);
if (MsContext::GetInstance()->save_graphs_flag()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root); draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
} }
MS_LOG(INFO) << "Now entering step auto parallel"; MS_LOG(INFO) << "Now entering step auto parallel";

View File

@ -271,7 +271,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
if (!result) { if (!result) {
MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
} }
if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) {
auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
auto func_graph = res->func_graph(); auto func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -295,20 +295,20 @@ bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res,
static bool IsCtrlSink() { static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance(); auto ms_ctx = MsContext::GetInstance();
if (ms_ctx->execution_mode() != kGraphMode) { if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
return false; return false;
} }
std::string device_target = ms_ctx->device_target(); std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target != kAscendDevice) { if (device_target != kAscendDevice) {
return false; return false;
} }
if (!ms_ctx->enable_task_sink()) { if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
return false; return false;
} }
if (!ms_ctx->is_multi_graph_sink()) { if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
return false; return false;
} }
return true; return true;
@ -325,13 +325,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) { if (CompileGraphs::ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false); bc_ptr->set_is_multi_graph_sink(false);
context_ptr->set_is_multi_graph_sink(false); context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
context_ptr->set_loop_sink_flag(false); context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
} else if (context_ptr->execution_mode() != kPynativeMode) { } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
std::string device_target = context_ptr->device_target(); std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == kAscendDevice && backend != kMsVm) { if (device_target == kAscendDevice && backend != kMsVm) {
bc_ptr->set_is_multi_graph_sink(true); bc_ptr->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true); context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
} }
} }

View File

@ -49,7 +49,7 @@ inline std::string GetFilePathName(const std::string &file_name) {
if (ms_context == nullptr) { if (ms_context == nullptr) {
MS_LOG(EXCEPTION) << "ms_context is nullptr"; MS_LOG(EXCEPTION) << "ms_context is nullptr";
} }
auto save_graphs_path = ms_context->save_graphs_path(); auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }

View File

@ -48,6 +48,56 @@ using OpLib = mindspore::kernel::OpLib;
using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
using ParallelContext = mindspore::parallel::ParallelContext; using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext; using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
namespace mindspore {
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '"
<< py::str(value.get_type()) << "'.";
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
ctx->set_param<bool>(param, value.cast<bool>());
return;
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
ctx->set_param<int>(param, value.cast<int>());
return;
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
return;
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
ctx->set_param<float>(param, value.cast<float>());
return;
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
ctx->set_param<std::string>(param, value.cast<std::string>());
return;
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type());
}
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
return py::bool_(ctx->get_param<bool>(param));
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
return py::int_(ctx->get_param<int>(param));
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
return py::int_(ctx->get_param<uint32_t>(param));
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
return py::float_(ctx->get_param<float>(param));
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
return py::str(ctx->get_param<std::string>(param));
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
}
} // namespace mindspore
// Interface with python // Interface with python
PYBIND11_MODULE(_c_expression, m) { PYBIND11_MODULE(_c_expression, m) {
@ -101,53 +151,48 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.");
(void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.");
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG)
.value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext") (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.") .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
.def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.")
.def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.")
.def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.")
.def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.")
.def("get_device_target", &mindspore::MsContext::device_target, "Get device target.")
.def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
.def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
.def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.")
.def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
"Set whether to enable auto mixed precision.")
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
"Get whether to enable reduce precision.")
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
"Set whether to enable reduce precision.")
.def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
.def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
.def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.")
.def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.")
.def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.")
.def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.")
.def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.")
.def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size,
"set variable memory max size")
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.")
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
.def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
.def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

View File

@ -271,7 +271,7 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) { if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
g_pass_opts["opt_graph_kernel_a"]->set_enable(false); g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
g_pass_opts["opt_graph_kernel_b"]->set_enable(false); g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
} }

View File

@ -88,7 +88,7 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) {
if (ms_context == nullptr) { if (ms_context == nullptr) {
MS_LOG(EXCEPTION) << "ms_context is nullptr"; MS_LOG(EXCEPTION) << "ms_context is nullptr";
} }
auto save_graphs_path = ms_context->save_graphs_path(); auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) { if (save_graphs_path.empty()) {
save_graphs_path = "."; save_graphs_path = ".";
} }
@ -646,7 +646,7 @@ void Pipeline::Run() {
if (!result) { if (!result) {
MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
} }
if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) {
auto graph = resource_->func_graph(); auto graph = resource_->func_graph();
if (graph != nullptr) { if (graph != nullptr) {
user_graph = graph; user_graph = graph;
@ -688,7 +688,7 @@ void Pipeline::Run() {
MsProfile::Reset(); MsProfile::Reset();
#endif #endif
if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) {
std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
draw::DrawUserFuncGraph(user_graph_file, user_graph); draw::DrawUserFuncGraph(user_graph_file, user_graph);
@ -710,7 +710,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if (!succ) { if (!succ) {
MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
} }
if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa<tensor::Tensor>()) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == 0 && !converted->isa<tensor::Tensor>()) {
MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString()
<< " is not tensor."; << " is not tensor.";
} }
@ -891,7 +891,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
// Convert CNodeList to LinConvertResult. // Convert CNodeList to LinConvertResult.
ConfigManager::GetInstance().set_iter_num(1); ConfigManager::GetInstance().set_iter_num(1);
auto runner = convert_fn({app_init}, ""); auto runner = convert_fn({app_init}, "");
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
backend->Link(runner.graph_id); backend->Link(runner.graph_id);
} }
ConfigManager::GetInstance().set_iter_num(size); ConfigManager::GetInstance().set_iter_num(size);
@ -965,10 +965,11 @@ void InitHccl() {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
(void)context::OpenTsd(ms_context); (void)context::OpenTsd(ms_context);
uint32_t device_id = ms_context->device_id(); uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
std::string device_name = ms_context->device_target(); std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
ms_context->set_enable_hccl(true); ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { if (ms_context->backend_policy() == "ms" &&
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
if (!runtime_instance->Init()) { if (!runtime_instance->Init()) {

View File

@ -214,7 +214,7 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
return false; return false;
} }
if (MsContext::GetInstance()->save_graphs_flag()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug
convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug
convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug
@ -244,7 +244,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
} }
FuncGraphPtr anf_graph = info.at(phase)->func_graph; FuncGraphPtr anf_graph = info.at(phase)->func_graph;
if (MsContext::GetInstance()->save_graphs_flag()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug
DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true);
} }

View File

@ -118,8 +118,9 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< ", current function call depth: " << engine->function_call_depth(); << ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr ret_base = nullptr; AbstractBasePtr ret_base = nullptr;
engine->IncreaseFunctionCallDepth(); engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) { if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << "."; MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << ".";
} }
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
@ -409,7 +410,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
bparams.push_back(SensitivityTransform(orig_func_)); bparams.push_back(SensitivityTransform(orig_func_));
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse(); bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
if (enable_sparse && arg_spec->isa<AbstractTensor>()) { if (enable_sparse && arg_spec->isa<AbstractTensor>()) {

View File

@ -62,7 +62,7 @@ class Evaluator : public Base {
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse(); bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (!enable_sparse) { if (!enable_sparse) {
return nullptr; return nullptr;
} }

View File

@ -290,7 +290,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if (abs_base->isa<AbstractTensor>()) { if (abs_base->isa<AbstractTensor>()) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base); auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
dic["shape"] = arg_tensor->shape()->shape(); dic["shape"] = arg_tensor->shape()->shape();
if (MsContext::GetInstance()->execution_mode() == kGraphMode) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape(); const auto &min_shape = arg_tensor->shape()->min_shape();
const auto &max_shape = arg_tensor->shape()->max_shape(); const auto &max_shape = arg_tensor->shape()->max_shape();
if (!min_shape.empty() && !max_shape.empty()) { if (!min_shape.empty() && !max_shape.empty()) {

View File

@ -558,8 +558,8 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(op_exec_info);
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
ms_context->set_enable_pynative_infer(true); ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
std::string device_target = ms_context->device_target(); std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target != kAscendDevice && device_target != kGPUDevice) { if (device_target != kAscendDevice && device_target != kGPUDevice) {
MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
} }
@ -567,7 +567,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if (session == nullptr) { if (session == nullptr) {
session = session::SessionFactory::Get().Create(device_target); session = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(session); MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id()); session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
} }
std::vector<tensor::TensorPtr> input_tensors; std::vector<tensor::TensorPtr> input_tensors;
@ -578,7 +578,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, &input_tensors); EraseValueNodeTensor(tensors_mask, &input_tensors);
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors); py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
ms_context->set_enable_pynative_infer(false); ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
*status = PYNATIVE_SUCCESS; *status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
return result; return result;
@ -1308,7 +1308,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
// Maybe exit in the pynative runing op, so need reset pynative flag. // Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
if (ms_context != nullptr) { if (ms_context != nullptr) {
ms_context->set_enable_pynative_infer(false); ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
} }
ConfigManager::GetInstance().ResetIterNum(); ConfigManager::GetInstance().ResetIterNum();
return; return;

View File

@ -89,7 +89,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size() MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size()
<< " is not equal to the args number: " << py_args.size() - 2 << "."; << " is not equal to the args number: " << py_args.size() - 2 << ".";
} }
if (MsContext::GetInstance()->check_bprop_flag()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
for (size_t i = 0; i < grads.size(); i++) { for (size_t i = 0; i < grads.size(); i++) {
if (py::isinstance<tensor::Tensor>(py_args[i])) { if (py::isinstance<tensor::Tensor>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i])) { if (!py::isinstance<tensor::Tensor>(grads[i])) {

View File

@ -154,7 +154,7 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) { DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
auto device_id = ms_context->device_id(); auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type); auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type);
@ -261,11 +261,12 @@ void AscendDeviceAddress::SyncStream() const {
MS_LOG(INFO) << "Start!"; MS_LOG(INFO) << "Start!";
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() != kPynativeMode && !ms_context->enable_pynative_infer()) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
return; return;
} }
auto device_id = ms_context->device_id(); auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
auto ret = runtime_instance->SyncStream(); auto ret = runtime_instance->SyncStream();
@ -348,7 +349,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
} }
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
auto device_id = ms_context->device_id(); auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
auto ret = auto ret =
@ -475,7 +476,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
std::vector<size_t> device_shape = GetDeviceShape(&host_shape); std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() != kGraphMode && ms_context->execution_mode() != kPynativeMode && if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode &&
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
type_id_name_map.find(type_id_) != type_id_name_map.end()) { type_id_name_map.find(type_id_) != type_id_name_map.end()) {
std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_); std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
if (use_trans_data.find(type_format) != use_trans_data.end()) { if (use_trans_data.find(type_format) != use_trans_data.end()) {

View File

@ -158,7 +158,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
bool AscendKernelRuntime::NeedDestroyHccl() { bool AscendKernelRuntime::NeedDestroyHccl() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->enable_hccl()) { if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
MS_LOG(INFO) << "Hccl is not enabled"; MS_LOG(INFO) << "Hccl is not enabled";
return false; return false;
} }
@ -177,7 +177,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto ret = rtSetDevice(context_ptr->device_id()); auto ret = rtSetDevice(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
} }
@ -461,12 +461,12 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->enable_task_sink(); bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!is_task_sink) { if (!is_task_sink) {
return true; return true;
} }
#ifdef MEM_REUSE_DEBUG #ifdef MEM_REUSE_DEBUG
if (!context_ptr->enable_mem_reuse()) { if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE)) {
// Get normal graph ir for memreuse // Get normal graph ir for memreuse
mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
} }
@ -518,7 +518,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->enable_task_sink(); bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (!is_task_sink) { if (!is_task_sink) {
return true; return true;
} }
@ -658,7 +658,7 @@ bool AscendKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "Get MsContext instance failed"; MS_LOG(ERROR) << "Get MsContext instance failed";
return false; return false;
} }
if (context_ptr->enable_hccl()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
if (!HcclInit()) { if (!HcclInit()) {
MS_LOG(ERROR) << "HcclInit init failed"; MS_LOG(ERROR) << "HcclInit init failed";
return false; return false;
@ -746,7 +746,7 @@ bool AscendKernelRuntime::DestroyHccl() {
return false; return false;
} }
MS_LOG(INFO) << "Hccl destroy successful, status = " << res << "."; MS_LOG(INFO) << "Hccl destroy successful, status = " << res << ".";
context_ptr->set_enable_hccl(false); context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
return true; return true;
} }

View File

@ -43,7 +43,7 @@ void AscendMemoryManager::MallocDeviceMemory() {
uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
auto variable_memory_max_size = context->variable_memory_max_size(); auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
if (variable_memory_max_size == "0") { if (variable_memory_max_size == "0") {
return 0; return 0;
} }

View File

@ -1373,7 +1373,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
bool AscendStreamAssign::IsTaskSink() { bool AscendStreamAssign::IsTaskSink() {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->enable_task_sink()) { if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(INFO) << "Task sink mode is not enable"; MS_LOG(INFO) << "Task sink mode is not enable";
return false; return false;
} else { } else {

View File

@ -117,7 +117,7 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
if (!dump_path.has_value()) { if (!dump_path.has_value()) {
MS_LOG(EXCEPTION) << "Dump path invalid"; MS_LOG(EXCEPTION) << "Dump path invalid";
} }
auto device_id = context_ptr->device_id(); auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/"); dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/");
MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value(); MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value();

View File

@ -363,7 +363,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
*precision_reduce = false; *precision_reduce = false;
return; return;
} }
if (context_ptr->enable_reduce_precision()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
kernel_support_datatype, &kernel_match_datatype_idx_copy); kernel_support_datatype, &kernel_match_datatype_idx_copy);
} }

View File

@ -117,7 +117,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
} }
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
const string prof_options_str = context->profiling_options(); const string prof_options_str = context->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
std::vector<string> opts = Split(prof_options_str, ':'); std::vector<string> opts = Split(prof_options_str, ':');
if (opts.empty()) { if (opts.empty()) {
MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!"; MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!";

View File

@ -41,7 +41,7 @@ class ProfilingManager {
inline bool IsProfiling() const { inline bool IsProfiling() const {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
return context->enable_profiling(); return context->get_param<bool>(MS_CTX_ENABLE_PROFILING);
} }
protected: protected:

View File

@ -342,12 +342,12 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second); TaskDescReporter task_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.task_desc_info", ret->second);
task_reporter.set_task_ids(task_ids); task_reporter.set_task_ids(task_ids);
task_reporter.set_stream_ids(stream_ids); task_reporter.set_stream_ids(stream_ids);
task_reporter.ReportData(); task_reporter.ReportData();
GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second); GraphDescReporter graph_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.graph_desc_info", ret->second);
graph_profiling_cnode_.erase(ret); graph_profiling_cnode_.erase(ret);
graph_reporter.ReportData(); graph_reporter.ReportData();
@ -357,7 +357,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
MS_LOG(ERROR) << "Graph id not found in graph_point"; MS_LOG(ERROR) << "Graph id not found in graph_point";
return; return;
} }
PointReporter point_reporter(context->device_id(), "vm.point"); PointReporter point_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.point");
for (const auto &point : point_iter->second) { for (const auto &point : point_iter->second) {
point_reporter.AddReportData(point); point_reporter.AddReportData(point);
} }

View File

@ -416,7 +416,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
mem_manager_->ResetDynamicMemory(); mem_manager_->ResetDynamicMemory();
AssignStaticMemoryInput(graph); AssignStaticMemoryInput(graph);
AssignStaticMemoryValueNode(graph); AssignStaticMemoryValueNode(graph);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
if (is_enable_dynamic_mem) { if (is_enable_dynamic_mem) {
// Use the dynamic memory pool. // Use the dynamic memory pool.
InitKernelRefCount(graph); InitKernelRefCount(graph);
@ -435,8 +435,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool ret = true; bool ret = true;
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); bool is_enable_pynative_infer = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
if (is_enable_dynamic_mem && !is_enable_pynative_infer) { if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
auto graph_id = graph->graph_id(); auto graph_id = graph->graph_id();
auto iter = mem_swap_map_.find(graph_id); auto iter = mem_swap_map_.find(graph_id);

View File

@ -29,7 +29,7 @@ bool GPUMemoryAllocator::Init() {
size_t free_size = CudaDriver::free_mem_size(); size_t free_size = CudaDriver::free_mem_size();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
limited_device_memory_ = context_ptr->max_device_memory(); limited_device_memory_ = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024); available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024);
if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) {
MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size
@ -44,7 +44,7 @@ bool GPUMemoryAllocator::Init() {
void GPUMemoryAllocator::CheckMaxDeviceMemory() const { void GPUMemoryAllocator::CheckMaxDeviceMemory() const {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto max_device_memory = context_ptr->max_device_memory(); auto max_device_memory = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
// Currently not support modifying the max device memory. // Currently not support modifying the max device memory.
if (limited_device_memory_ != max_device_memory) { if (limited_device_memory_ != max_device_memory) {
MS_LOG(EXCEPTION) MS_LOG(EXCEPTION)

View File

@ -37,7 +37,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
// If use the dynamic memory pool, then alloc the first memory block to init. // If use the dynamic memory pool, then alloc the first memory block to init.
if (context_ptr->enable_dynamic_mem_pool()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
auto device_addr = MallocMemFromMemPool(1); auto device_addr = MallocMemFromMemPool(1);
if (!device_addr) { if (!device_addr) {
MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";
@ -65,7 +65,7 @@ void GPUMemoryManager::FreeDeviceMemory() {
uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_dynamic_mem_pool()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
auto device_ptr = MallocMemFromMemPool(size); auto device_ptr = MallocMemFromMemPool(size);
MS_EXCEPTION_IF_NULL(device_ptr); MS_EXCEPTION_IF_NULL(device_ptr);
return AddressOffset(device_ptr, 0); return AddressOffset(device_ptr, 0);

View File

@ -162,7 +162,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) { bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return false; return false;
} }
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) { if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {

View File

@ -60,8 +60,8 @@ void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &k
bool KernelAdjust::NeedInsertSwitch() { bool KernelAdjust::NeedInsertSwitch() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && return (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) &&
ConfigManager::GetInstance().iter_num() > 1); context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1);
} }
CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,

View File

@ -50,7 +50,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
struct timeval start_time, end_time; struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr); (void)gettimeofday(&start_time, nullptr);
#endif #endif
bool is_task_sink = context_ptr->enable_task_sink(); bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (is_task_sink) { if (is_task_sink) {
ret = RunTask(graph); ret = RunTask(graph);
} else { } else {
@ -502,7 +502,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
MS_LOG(INFO) << "communication op addr exist"; MS_LOG(INFO) << "communication op addr exist";
continue; continue;
} }
if (context_ptr->enable_hccl()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
mem_size = mem_manager_->GetCommonAlignSize(mem_size); mem_size = mem_manager_->GetCommonAlignSize(mem_size);
} }
total_size += mem_size; total_size += mem_size;
@ -646,7 +646,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
DeviceAddressPtr address = nullptr; DeviceAddressPtr address = nullptr;
address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) { if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, node_size)) {
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size; MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
} else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) { } else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size; MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
@ -682,7 +683,8 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
DeviceAddressPtr address = nullptr; DeviceAddressPtr address = nullptr;
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) { if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size; MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
} else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { } else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
@ -701,7 +703,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); bool is_enable_mem_reuse = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE);
auto mem_type = kDynamicMem; auto mem_type = kDynamicMem;
if (is_enable_mem_reuse) { if (is_enable_mem_reuse) {
mem_manager_->MallocReusedDynamicMem(graph); mem_manager_->MallocReusedDynamicMem(graph);

View File

@ -54,7 +54,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me
uint8_t *ptr = nullptr; uint8_t *ptr = nullptr;
if (AnfAlgo::IsCommunicationOp(node)) { if (AnfAlgo::IsCommunicationOp(node)) {
bool communication_mem = false; bool communication_mem = false;
if (context_ptr->enable_hccl()) { if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
communication_mem = true; communication_mem = true;
} }
if (type == kStaticMem) { if (type == kStaticMem) {

View File

@ -1070,7 +1070,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
convertor.inputs_ = inputs; convertor.inputs_ = inputs;
(void)convertor.ConvertAllNode().BuildGraph(); (void)convertor.ConvertAllNode().BuildGraph();
std::string name = graph_node->ToString() + "_ge_graph.dot"; std::string name = graph_node->ToString() + "_ge_graph.dot";
if (MsContext::GetInstance()->save_graphs_flag()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
convertor.DrawComputeGraph(name); convertor.DrawComputeGraph(name);
} }
branches_map_[node.get()] = *(convertor.df_graph_); branches_map_[node.get()] = *(convertor.df_graph_);

View File

@ -41,13 +41,13 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
if (ms_context_ptr->is_pynative_ge_init()) { if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
return true; return true;
} }
if (ms_context_ptr->tsd_ref()) { if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
MS_LOG(DEBUG) << "TDT Dataset client is already opened."; MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
ms_context_ptr->set_tsd_ref("++"); ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
return true; return true;
} }
@ -59,7 +59,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
unsigned int device_id; unsigned int device_id;
unsigned int rank_size = 1; unsigned int rank_size = 1;
device_id = ms_context_ptr->device_id(); device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto rank_size_env = common::GetEnv("RANK_SIZE"); auto rank_size_env = common::GetEnv("RANK_SIZE");
if (rank_size_env.empty()) { if (rank_size_env.empty()) {
@ -79,7 +79,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
return false; return false;
} }
ms_context_ptr->set_tsd_ref("++"); ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
int32_t initStatus = tdt::TdtHostInit(device_id); int32_t initStatus = tdt::TdtHostInit(device_id);
if (initStatus != TDT_OK_CODE) { if (initStatus != TDT_OK_CODE) {
@ -88,7 +88,8 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
} }
ms_context_ptr->tdt_print_ = std::thread(TensorPrint()); ms_context_ptr->tdt_print_ = std::thread(TensorPrint());
#endif #endif
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << ms_context_ptr->tsd_ref() << "."; MS_LOG(INFO) << "Open and init tsd successful, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
return true; return true;
} }
@ -96,12 +97,12 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) { if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
if (ms_context_ptr->tsd_ref() == 0) { if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true; return true;
} }
ms_context_ptr->set_tsd_ref("--"); ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
if (force || ms_context_ptr->tsd_ref() == 0) { if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
ms_context_ptr->set_tsd_ref(" "); ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
#ifdef ENABLE_TDTQUE #ifdef ENABLE_TDTQUE
int32_t stopStatus = tdt::TdtHostStop(KNpuLog); int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
if (stopStatus != TDT_OK_CODE) { if (stopStatus != TDT_OK_CODE) {
@ -123,17 +124,17 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
} }
#endif #endif
auto device_id = ms_context_ptr->device_id(); auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
TDT_StatusT status = TsdClose(device_id); TDT_StatusT status = TsdClose(device_id);
if (status != TDT_OK) { if (status != TDT_OK) {
MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << "."; MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << ".";
return false; return false;
} }
ms_context_ptr->set_pynative_ge_init(false); ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << "."; MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << ".";
} else { } else {
MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->tsd_ref() MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = "
<< "."; << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
} }
return true; return true;
@ -159,14 +160,14 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
} }
#ifdef ENABLE_GE #ifdef ENABLE_GE
(*ge_options)["device_id"] = "0"; (*ge_options)["device_id"] = "0";
(*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->enable_dump()); (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
(*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->save_dump_path(); (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
(*ge_options)["ge.exec.dumpMode"] = "output"; (*ge_options)["ge.exec.dumpMode"] = "output";
MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->enable_dump()) MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
<< " and save dump path is " << ms_context_ptr->save_dump_path() << "."; << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
(*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->enable_profiling()); (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING));
if (ms_context_ptr->enable_profiling()) { if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)) {
(*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->profiling_options(); (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
} }
(*ge_options)["rank_table_file"] = ""; (*ge_options)["rank_table_file"] = "";
@ -178,12 +179,12 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
} }
(*ge_options)["graphType"] = "1"; (*ge_options)["graphType"] = "1";
if (ms_context_ptr->graph_memory_max_size() != "0") { if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->graph_memory_max_size(); (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
} }
if (ms_context_ptr->variable_memory_max_size() != "0") { if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->variable_memory_max_size(); (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
} }
#if ENABLE_TRAIN == 1 #if ENABLE_TRAIN == 1
@ -224,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
} }
// Enable auto mixed precision according to the context options // Enable auto mixed precision according to the context options
if (ms_context_ptr->auto_mixed_precision_flag()) { if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) {
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
} else { } else {
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
@ -240,7 +241,7 @@ void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<s
} }
auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto env_rank_id = common::GetEnv("RANK_ID"); auto env_rank_id = common::GetEnv("RANK_ID");
auto env_device_id = std::to_string(ms_context_ptr->device_id()); auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
if (!(env_table_file.empty() || env_rank_id.empty())) { if (!(env_table_file.empty() || env_rank_id.empty())) {
MS_LOG(INFO) << "Initialize Ge for distribute parameter"; MS_LOG(INFO) << "Initialize Ge for distribute parameter";
MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."; MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
@ -275,12 +276,12 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
#ifdef ENABLE_GE #ifdef ENABLE_GE
if (ms_context_ptr->is_pynative_ge_init()) { if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
return true; return true;
} }
if (ms_context_ptr->ge_ref()) { if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
ms_context_ptr->set_ge_ref("++"); ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
return true; return true;
} }
@ -293,8 +294,8 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG(EXCEPTION) << "Initialize GE failed!"; MS_LOG(EXCEPTION) << "Initialize GE failed!";
} }
} }
ms_context_ptr->set_ge_ref("++"); ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->ge_ref() << "."; MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
#endif #endif
return true; return true;
} }
@ -303,12 +304,13 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) { if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
if (ms_context_ptr->is_pynative_ge_init() || ms_context_ptr->ge_ref() || ms_context_ptr->tsd_ref()) { if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
return true; return true;
} }
(void)OpenTsd(ms_context_ptr); (void)OpenTsd(ms_context_ptr);
(void)InitGe(ms_context_ptr); (void)InitGe(ms_context_ptr);
ms_context_ptr->set_pynative_ge_init(true); ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
return true; return true;
} }
@ -317,12 +319,12 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
#ifdef ENABLE_GE #ifdef ENABLE_GE
if (ms_context_ptr->ge_ref() == 0) { if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
return true; return true;
} }
ms_context_ptr->set_ge_ref("--"); ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
if (force || ms_context_ptr->ge_ref() == 0) { if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
ms_context_ptr->set_ge_ref(" "); ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
try { try {
DfGraphManager::GetInstance().DeleteGraphRunner(); DfGraphManager::GetInstance().DeleteGraphRunner();
DfGraphManager::GetInstance().DeleteGeSession(); DfGraphManager::GetInstance().DeleteGeSession();
@ -337,7 +339,8 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
} }
ms_context_ptr->set_pynative_ge_init(false); ms_context_ptr->set_pynative_ge_init(false);
} else { } else {
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->ge_ref() << "."; MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
} }
#endif #endif
return true; return true;
@ -347,14 +350,14 @@ bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) { if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
return ms_context_ptr->IsTsdOpened(); return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
} }
bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) { bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
if (ms_context_ptr == nullptr) { if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
return ms_context_ptr->IsGeInited(); return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
} }
// Register for device type. // Register for device type.

View File

@ -353,7 +353,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
// When sparse enabled, the undetermined might be raised and eliminated in opt passes // When sparse enabled, the undetermined might be raised and eliminated in opt passes
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse(); bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse) { if (enable_sparse) {
return std::make_shared<abstract::AbstractUndetermined>(); return std::make_shared<abstract::AbstractUndetermined>();
} }

View File

@ -273,7 +273,7 @@ void TensorPrint::operator()() {
prntpb::Print print; prntpb::Print print;
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
std::string print_file_path = ms_context->print_file_path(); std::string print_file_path = ms_context->get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
if (print_file_path == "") { if (print_file_path == "") {
while (true) { while (true) {
std::vector<tdt::DataItem> bundle; std::vector<tdt::DataItem> bundle;

View File

@ -59,7 +59,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
graph_id = target_sess_->CompileGraphAsync(lst, outputs); graph_id = target_sess_->CompileGraphAsync(lst, outputs);
} }
if (MsContext::GetInstance()->precompile_only()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph"; MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result; return result;
} }
@ -180,7 +180,7 @@ void MsBackend::CreateOtherSession(const std::string &target) {
} }
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
uint32_t device_id = context_ptr->device_id(); uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
other_sess_->Init(device_id); other_sess_->Init(device_id);
other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
other_device_ = target; other_device_ = target;

View File

@ -56,7 +56,7 @@ namespace {
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
std::string last_target = context_ptr->device_target(); std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
std::string cur_target = GetCNodeTarget(node); std::string cur_target = GetCNodeTarget(node);
@ -348,7 +348,7 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
if (prim->name() == prim::kPrimBpropCut->name()) { if (prim->name() == prim::kPrimBpropCut->name()) {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_enable_pynative_hook(true); ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
} }
if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
@ -412,7 +412,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
if (ContainMultiTarget(nodes)) { if (ContainMultiTarget(nodes)) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target(); std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
nodes = SplitSort(graph, default_target); nodes = SplitSort(graph, default_target);
return SplitNodesWithTarget(nodes, graph); return SplitNodesWithTarget(nodes, graph);
} }
@ -920,17 +920,17 @@ BackendPtr CreateBackend() {
} }
if (name == kMsConvert) { if (name == kMsConvert) {
std::string target = context_ptr->device_target(); std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->device_id(); uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
auto backend = std::make_shared<MsBackend>(name, target, device_id); auto backend = std::make_shared<MsBackend>(name, target, device_id);
std::string device_target = MsContext::GetInstance()->device_target(); std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == kAscendDevice) { if (device_target == kAscendDevice) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
backend->set_is_multi_graph_sink(false); backend->set_is_multi_graph_sink(false);
context_ptr->set_is_multi_graph_sink(false); context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
} else { } else {
backend->set_is_multi_graph_sink(true); backend->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true); context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
} }
} }
return backend; return backend;

View File

@ -22,7 +22,7 @@ import threading
from collections import namedtuple from collections import namedtuple
from types import FunctionType from types import FunctionType
from mindspore import log as logger from mindspore import log as logger
from mindspore._c_expression import MSContext from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param
from mindspore._checkparam import args_type_check from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context _reset_auto_parallel_context
@ -157,9 +157,15 @@ class _Context:
raise ValueError("Context handle is none in context!!!") raise ValueError("Context handle is none in context!!!")
return value return value
def get_param(self, param):
return ms_ctx_get_param(self._context_handle, param)
def set_param(self, param, value):
ms_ctx_set_param(self._context_handle, param, value)
@property @property
def mode(self): def mode(self):
return self._context_handle.get_execution_mode() return self.get_param(ms_ctx_param.execution_mode)
@mode.setter @mode.setter
def mode(self, mode): def mode(self, mode):
@ -169,15 +175,17 @@ class _Context:
Args: Args:
mode (int): GRAPH_MODE or PYNATIVE_MODE. mode (int): GRAPH_MODE or PYNATIVE_MODE.
""" """
self._context_handle.set_execution_mode(mode)
if mode == PYNATIVE_MODE: if mode == PYNATIVE_MODE:
if self.enable_debug_runtime: if self.enable_debug_runtime:
self.set_backend_policy("vm") self.set_backend_policy("vm")
self._context_switches.push(True, None) self._context_switches.push(True, None)
else: elif mode == GRAPH_MODE:
if self.enable_debug_runtime: if self.enable_debug_runtime:
self.set_backend_policy("ge") self.set_backend_policy("ge")
self._context_switches.push(False, None) self._context_switches.push(False, None)
else:
raise ValueError(f'The execution mode {mode} is invalid!')
self.set_param(ms_ctx_param.execution_mode, mode)
def set_backend_policy(self, policy): def set_backend_policy(self, policy):
success = self._context_handle.set_backend_policy(policy) success = self._context_handle.set_backend_policy(policy)
@ -186,110 +194,106 @@ class _Context:
@property @property
def precompile_only(self): def precompile_only(self):
return self._context_handle.get_precompile_only() return self.get_param(ms_ctx_param.precompile_only)
@precompile_only.setter @precompile_only.setter
def precompile_only(self, precompile_only): def precompile_only(self, precompile_only):
self._context_handle.set_precompile_only(precompile_only) self.set_param(ms_ctx_param.precompile_only, precompile_only)
@property @property
def save_graphs(self): def save_graphs(self):
return self._context_handle.get_save_graphs_flag() return self.get_param(ms_ctx_param.save_graphs_flag)
@save_graphs.setter @save_graphs.setter
def save_graphs(self, save_graphs_flag): def save_graphs(self, save_graphs_flag):
self._context_handle.set_save_graphs_flag(save_graphs_flag) self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag)
@property @property
def save_graphs_path(self): def save_graphs_path(self):
return self._context_handle.get_save_graphs_path() return self.get_param(ms_ctx_param.save_graphs_path)
@save_graphs_path.setter @save_graphs_path.setter
def save_graphs_path(self, save_graphs_path): def save_graphs_path(self, save_graphs_path):
self._context_handle.set_save_graphs_path( self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
_make_directory(save_graphs_path))
@property @property
def device_target(self): def device_target(self):
return self._context_handle.get_device_target() return self.get_param(ms_ctx_param.device_target)
@device_target.setter @device_target.setter
def device_target(self, target): def device_target(self, target):
success = self._context_handle.set_device_target(target) valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
if not success: if not target in valid_targets:
raise ValueError("Target device name is invalid!!!") raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
if self.enable_debug_runtime and self.device_target == "CPU": if target == "Davinci":
target = "Ascend"
self.set_param(ms_ctx_param.device_target, target)
if self.enable_debug_runtime and target == "CPU":
self.set_backend_policy("vm") self.set_backend_policy("vm")
@property @property
def device_id(self): def device_id(self):
return self._context_handle.get_device_id() return self.get_param(ms_ctx_param.device_id)
@device_id.setter @device_id.setter
def device_id(self, device_id): def device_id(self, device_id):
if device_id < 0 or device_id > 4095: if device_id < 0 or device_id > 4095:
raise ValueError( raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
"Device id must be in [0, 4095], but got {}".format(device_id)) self.set_param(ms_ctx_param.device_id, device_id)
success = self._context_handle.set_device_id(device_id)
if not success:
raise RuntimeError("Device id set failed!!!")
@property @property
def max_call_depth(self): def max_call_depth(self):
return self._context_handle.get_max_call_depth() return self.get_param(ms_ctx_param.max_call_depth)
@max_call_depth.setter @max_call_depth.setter
def max_call_depth(self, max_call_depth): def max_call_depth(self, max_call_depth):
if max_call_depth <= 0: if max_call_depth <= 0:
raise ValueError( raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
"Max call depth must be greater than 0, but got {}".format(max_call_depth)) self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
self._context_handle.set_max_call_depth(max_call_depth)
@property @property
def enable_auto_mixed_precision(self): def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag() return self.get_param(ms_ctx_param.auto_mixed_precision_flag)
@enable_auto_mixed_precision.setter @enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision): def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self._context_handle.set_auto_mixed_precision_flag( self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision)
enable_auto_mixed_precision)
@property @property
def enable_reduce_precision(self): def enable_reduce_precision(self):
return self._context_handle.get_enable_reduce_precision_flag() return self.get_param(ms_ctx_param.enable_reduce_precision_flag)
@enable_reduce_precision.setter @enable_reduce_precision.setter
def enable_reduce_precision(self, enable_reduce_precision): def enable_reduce_precision(self, enable_reduce_precision):
self._context_handle.set_enable_reduce_precision_flag( self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision)
enable_reduce_precision)
@property @property
def enable_dump(self): def enable_dump(self):
return self._context_handle.get_enable_dump() return self.get_param(ms_ctx_param.enable_dump)
@enable_dump.setter @enable_dump.setter
def enable_dump(self, enable_dump): def enable_dump(self, enable_dump):
self._context_handle.set_enable_dump(enable_dump) self.set_param(ms_ctx_param.enable_dump, enable_dump)
@property @property
def save_dump_path(self): def save_dump_path(self):
return self._context_handle.get_save_dump_path() return self.get_param(ms_ctx_param.save_dump_path)
@save_dump_path.setter @save_dump_path.setter
def save_dump_path(self, save_dump_path): def save_dump_path(self, save_dump_path):
self._context_handle.set_save_dump_path(save_dump_path) self.set_param(ms_ctx_param.save_dump_path, save_dump_path)
@property @property
def enable_profiling(self): def enable_profiling(self):
return self._context_handle.get_enable_profiling() return self.get_param(ms_ctx_param.enable_profiling)
@enable_profiling.setter @enable_profiling.setter
def enable_profiling(self, flag): def enable_profiling(self, flag):
self._context_handle.set_enable_profiling(flag) self.set_param(ms_ctx_param.enable_profiling, flag)
@property @property
def profiling_options(self): def profiling_options(self):
return self._context_handle.get_profiling_options() return self.get_param(ms_ctx_param.profiling_options)
@profiling_options.setter @profiling_options.setter
def profiling_options(self, option): def profiling_options(self, option):
@ -298,15 +302,15 @@ class _Context:
if option not in options: if option not in options:
raise ValueError("Profiling options must be in 'training_trace' 'task_trace' " raise ValueError("Profiling options must be in 'training_trace' 'task_trace' "
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
self._context_handle.set_profiling_options(option) self.set_param(ms_ctx_param.profiling_options, option)
@property @property
def enable_graph_kernel(self): def enable_graph_kernel(self):
return self._context_handle.get_enable_graph_kernel() return self.get_param(ms_ctx_param.enable_graph_kernel)
@enable_graph_kernel.setter @enable_graph_kernel.setter
def enable_graph_kernel(self, graph_kernel_switch_): def enable_graph_kernel(self, graph_kernel_switch_):
self._context_handle.set_enable_graph_kernel(graph_kernel_switch_) self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_)
@property @property
def reserve_class_name_in_scope(self): def reserve_class_name_in_scope(self):
@ -325,20 +329,14 @@ class _Context:
@variable_memory_max_size.setter @variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size): def variable_memory_max_size(self, variable_memory_max_size):
if not check_input_format(variable_memory_max_size): if not check_input_format(variable_memory_max_size):
raise ValueError( raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
"Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError( raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
"Context param variable_memory_max_size should be less than 31GB.") variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
variable_memory_max_size_ = variable_memory_max_size[:- graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
2] + " * 1024 * 1024 * 1024" graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - \ self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
int(variable_memory_max_size[:-2]) self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
graph_memory_max_size_ = str(
graph_memory_max_size) + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(
variable_memory_max_size_)
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
@property @property
def enable_ge(self): def enable_ge(self):
@ -355,15 +353,15 @@ class _Context:
@property @property
def check_bprop(self): def check_bprop(self):
return self._context_handle.get_check_bprop_flag() return self.get_param(ms_ctx_param.check_bprop_flag)
@check_bprop.setter @check_bprop.setter
def check_bprop(self, check_bprop_flag): def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag) self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag)
@property @property
def max_device_memory(self): def max_device_memory(self):
return self._context_handle.get_max_device_memory() return self.get_param(ms_ctx_param.max_device_memory)
@max_device_memory.setter @max_device_memory.setter
def max_device_memory(self, max_device_memory): def max_device_memory(self, max_device_memory):
@ -372,7 +370,7 @@ class _Context:
max_device_memory_value = float(max_device_memory[:-2]) max_device_memory_value = float(max_device_memory[:-2])
if max_device_memory_value == 0: if max_device_memory_value == 0:
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
self._context_handle.set_max_device_memory(max_device_memory_value) self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
@property @property
def print_file_path(self): def print_file_path(self):
@ -392,15 +390,15 @@ class _Context:
full_file_name = os.path.join(path, file_name) full_file_name = os.path.join(path, file_name)
else: else:
full_file_name = print_file_path full_file_name = print_file_path
self._context_handle.set_print_file_path(full_file_name) self.set_param(ms_ctx_param.print_file_path, full_file_name)
@property @property
def enable_sparse(self): def enable_sparse(self):
return self._context_handle.get_enable_sparse() return self.get_param(ms_ctx_param.enable_sparse)
@enable_sparse.setter @enable_sparse.setter
def enable_sparse(self, enable_sparse): def enable_sparse(self, enable_sparse):
self._context_handle.set_enable_sparse(enable_sparse) self.set_param(ms_ctx_param.enable_sparse, enable_sparse)
def check_input_format(x): def check_input_format(x):
import re import re
@ -486,8 +484,6 @@ def set_auto_parallel_context(**kwargs):
full_batch (bool): Whether to load the whole batch on each device. Default: False. full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving. data parallel training in the benefit of time and memory saving.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises: Raises:
ValueError: If input key is not attribute in auto parallel context. ValueError: If input key is not attribute in auto parallel context.
@ -501,7 +497,6 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(parameter_broadcast=False) >>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(max_call_depth=80)
""" """
_set_auto_parallel_context(**kwargs) _set_auto_parallel_context(**kwargs)
@ -603,6 +598,7 @@ def set_context(**kwargs):
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
suffix to the file. suffix to the file.
enable_sparse (bool): Whether to enable sparsity feature. Default: False. enable_sparse (bool): Whether to enable sparsity feature. Default: False.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises: Raises:
ValueError: If input key is not an attribute in context. ValueError: If input key is not an attribute in context.
@ -623,6 +619,7 @@ def set_context(**kwargs):
>>> context.set_context(enable_profiling=True, profiling_options="training_trace") >>> context.set_context(enable_profiling=True, profiling_options="training_trace")
>>> context.set_context(max_device_memory="3.5GB") >>> context.set_context(max_device_memory="3.5GB")
>>> context.set_context(print_file_path="print.pb") >>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
if not hasattr(_context(), key): if not hasattr(_context(), key):

View File

@ -51,7 +51,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse(); bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (enable_sparse && dflt->isa<AbstractTensor>()) { if (enable_sparse && dflt->isa<AbstractTensor>()) {
auto dflt_tensor = dflt->cast<AbstractTensorPtr>(); auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());

View File

@ -232,7 +232,7 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target(); std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
return default_target; return default_target;
} }
@ -248,7 +248,7 @@ std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &pri
std::string GetCNodeTarget(const AnfNodePtr &node) { std::string GetCNodeTarget(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->device_target(); std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return default_target; return default_target;
} }

View File

@ -652,7 +652,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph->set_param_default_value(item.first, cloner[item.second]); new_func_graph->set_param_default_value(item.first, cloner[item.second]);
} }
if (MsContext::GetInstance()->is_multi_graph_sink()) { if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }

View File

@ -30,7 +30,7 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) { FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse(); bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (!enable_sparse) { if (!enable_sparse) {
return nullptr; return nullptr;
} }

View File

@ -32,49 +32,50 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke
{"vm_prior", kMsBackendVmPrior}}; {"vm_prior", kMsBackendVmPrior}};
MsContext::MsContext(const std::string &policy, const std::string &target) { MsContext::MsContext(const std::string &policy, const std::string &target) {
save_graphs_flag_ = false; set_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG, false);
save_graphs_path_ = "."; set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
enable_dump_ = false; set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
save_dump_path_ = "."; set_param<uint32_t>(MS_CTX_TSD_REF, 0);
tsd_ref_ = 0; set_param<uint32_t>(MS_CTX_GE_REF, 0);
ge_ref_ = 0; set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
is_multi_graph_sink_ = false; set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
is_pynative_ge_init_ = false; set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
enable_reduce_precision_ = true;
auto env_device = common::GetEnv("DEVICE_ID"); auto env_device = common::GetEnv("DEVICE_ID");
if (!env_device.empty()) { if (!env_device.empty()) {
device_id_ = UlongToUint(std::stoul(env_device.c_str())); uint32_t device_id = UlongToUint(std::stoul(env_device.c_str()));
set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
} else { } else {
device_id_ = 0; set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
} }
max_call_depth_ = MAX_CALL_DEPTH_DEFAULT; set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
backend_policy_ = policy_map_[policy]; set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
device_target_ = target; set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
execution_mode_ = kPynativeMode; set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
enable_task_sink_ = true; set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
ir_fusion_flag_ = true; set_param<bool>(MS_CTX_ENABLE_HCCL, false);
enable_hccl_ = false;
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
enable_mem_reuse_ = false; set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, false);
#else #else
enable_mem_reuse_ = true; set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, true);
#endif #endif
enable_gpu_summary_ = true; set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
precompile_only_ = false; set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
auto_mixed_precision_flag_ = false; set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false);
enable_pynative_infer_ = false; set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
enable_pynative_hook_ = false; set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
enable_dynamic_mem_pool_ = true; set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
graph_memory_max_size_ = "0"; set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
variable_memory_max_size_ = "0"; set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice; set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
profiling_mode_ = false; set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
profiling_options_ = "training_trace"; set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
check_bprop_flag_ = false; set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
max_device_memory_ = kDefaultMaxDeviceMemory; set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
print_file_path_ = ""; set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
enable_graph_kernel_ = false; set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
enable_sparse_ = false; set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
backend_policy_ = policy_map_[policy];
} }
std::shared_ptr<MsContext> MsContext::GetInstance() { std::shared_ptr<MsContext> MsContext::GetInstance() {
@ -106,54 +107,4 @@ std::string MsContext::backend_policy() const {
} }
return "unknown"; return "unknown";
} }
void MsContext::set_execution_mode(int execution_mode) {
if (execution_mode != kGraphMode && execution_mode != kPynativeMode) {
MS_LOG(EXCEPTION) << "The execution mode is invalid!";
}
execution_mode_ = execution_mode;
}
bool MsContext::set_device_target(const std::string &target) {
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(ERROR) << "invalid device target name: " << target;
return false;
}
if (target == kDavinciDevice) {
device_target_ = kAscendDevice;
} else {
device_target_ = target;
}
if (seter_) {
seter_(device_target_);
}
MS_LOG(INFO) << "ms set context device target:" << target;
return true;
}
bool MsContext::set_device_id(uint32_t device_id) {
device_id_ = device_id;
MS_LOG(INFO) << "ms set context device id:" << device_id;
return true;
}
void MsContext::set_tsd_ref(const std::string &op) {
if (op == "--") {
tsd_ref_--;
} else if (op == "++") {
tsd_ref_++;
} else {
tsd_ref_ = 0;
}
}
void MsContext::set_ge_ref(const std::string &op) {
if (op == "--") {
ge_ref_--;
} else if (op == "++") {
ge_ref_++;
} else {
ge_ref_ = 0;
}
}
} // namespace mindspore } // namespace mindspore

View File

@ -49,6 +49,69 @@ const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice,
// The default max available device memory is 1024GB. // The default max available device memory is 1024GB.
const float kDefaultMaxDeviceMemory = 1024; const float kDefaultMaxDeviceMemory = 1024;
// enum definition for MindSpore Context Parameter
enum MsCtxParam : unsigned {
// paramater of type bool
MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_CHECK_BPROP_FLAG,
MS_CTX_ENABLE_DUMP,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
MS_CTX_ENABLE_GPU_SUMMARY,
MS_CTX_ENABLE_GRAPH_KERNEL,
MS_CTX_ENABLE_HCCL,
MS_CTX_ENABLE_LOOP_SINK,
MS_CTX_ENABLE_MEM_REUSE,
MS_CTX_ENABLE_PYNATIVE_HOOK,
MS_CTX_ENABLE_PYNATIVE_INFER,
MS_CTX_ENABLE_REDUCE_PRECISION,
MS_CTX_ENABLE_SPARSE,
MS_CTX_ENABLE_TASK_SINK,
MS_CTX_IR_FUSION_FLAG,
MS_CTX_IS_MULTI_GRAPH_SINK,
MS_CTX_IS_PYNATIVE_GE_INIT,
MS_CTX_PRECOMPILE_ONLY,
MS_CTX_ENABLE_PROFILING,
MS_CTX_SAVE_GRAPHS_FLAG,
MS_CTX_TYPE_BOOL_END,
// paramater of type int
MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
MS_CTX_TYPE_INT_END,
// paramater of type uint32
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
MS_CTX_GE_REF,
MS_CTX_MAX_CALL_DEPTH,
MS_CTX_TSD_REF,
MS_CTX_TYPE_UINT32_END,
// paramater of type float
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
MS_CTX_TYPE_FLOAT_END,
// paramater of type string
MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
MS_CTX_GRAPH_MEMORY_MAX_SIZE,
MS_CTX_PRINT_FILE_PATH,
MS_CTX_PROFILING_OPTIONS,
MS_CTX_SAVE_DUMP_PATH,
MS_CTX_SAVE_GRAPHS_PATH,
MS_CTX_VARIABLE_MEMORY_MAX_SIZE,
MS_CTX_TYPE_STRING_END,
// parameter numbers of each type
NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN,
NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN,
NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN,
NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN,
NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN
};
class MsContext { class MsContext {
public: public:
MsContext(const std::string &backend_policy, const std::string &target); MsContext(const std::string &backend_policy, const std::string &target);
@ -62,156 +125,113 @@ class MsContext {
std::string backend_policy() const; std::string backend_policy() const;
bool set_backend_policy(const std::string &policy); bool set_backend_policy(const std::string &policy);
int execution_mode() const { return execution_mode_; }
void set_execution_mode(int execution_mode);
bool enable_pynative_infer() const { return enable_pynative_infer_; }
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
bool enable_pynative_hook() const { return enable_pynative_hook_; }
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }
bool enable_task_sink() const { return enable_task_sink_; }
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
bool precompile_only() const { return precompile_only_; }
std::string device_target() const { return device_target_; }
bool set_device_target(const std::string &target);
uint32_t device_id() const { return device_id_; }
bool set_device_id(uint32_t device_id);
// uint32_t max_call_depth_
uint32_t max_call_depth() const { return max_call_depth_; }
inline bool set_max_call_depth(uint32_t max_call_depth) {
max_call_depth_ = max_call_depth;
return true;
}
bool save_graphs_flag() const { return save_graphs_flag_; }
void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; }
std::string save_graphs_path() const { return save_graphs_path_; }
void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; }
bool IsGeInited() { return ge_ref_ > 0; }
void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; }
bool enable_hccl() const { return enable_hccl_; }
bool ir_fusion_flag() const { return ir_fusion_flag_; }
bool loop_sink_flag() const { return enable_loop_sink_; }
void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; }
void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
bool enable_mem_reuse() const { return enable_mem_reuse_; }
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
bool enable_gpu_summary() const { return enable_gpu_summary_; }
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
}
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; }
void set_enable_dump(bool flag) { enable_dump_ = flag; }
bool enable_dump() const { return enable_dump_; }
void set_save_dump_path(const std::string &path) { save_dump_path_ = path; }
std::string save_dump_path() const { return save_dump_path_; }
bool IsTsdOpened() const { return tsd_ref_ > 0; }
void set_tsd_ref(const std::string &op);
uint32_t tsd_ref() const { return tsd_ref_; }
void set_ge_ref(const std::string &op);
uint32_t ge_ref() const { return ge_ref_; }
bool is_pynative_ge_init() { return is_pynative_ge_init_; }
void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; }
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; }
bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; }
void set_graph_memory_max_size(const std::string &graph_memory_max_size) {
graph_memory_max_size_ = graph_memory_max_size;
}
void set_variable_memory_max_size(const std::string &variable_memory_max_size) {
variable_memory_max_size_ = variable_memory_max_size;
}
const std::string &variable_memory_max_size() const { return variable_memory_max_size_; }
const std::string &graph_memory_max_size() const { return graph_memory_max_size_; }
void set_enable_profiling(bool flag) { profiling_mode_ = flag; }
bool enable_profiling() const { return profiling_mode_; }
void set_profiling_options(const std::string &options) { profiling_options_ = options; }
std::string profiling_options() const { return profiling_options_; }
bool check_bprop_flag() const { return check_bprop_flag_; }
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
void set_print_file_path(const std::string &file) { print_file_path_ = file; }
const std::string &print_file_path() const { return print_file_path_; }
float max_device_memory() const { return max_device_memory_; }
void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; }
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
bool enable_graph_kernel() const { return enable_graph_kernel_; }
bool enable_sparse() const { return enable_sparse_; }
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
static void device_seter(DeviceSeter device) { seter_ = device; } static void device_seter(DeviceSeter device) { seter_ = device; }
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
std::thread tdt_print_; std::thread tdt_print_;
template <typename T>
void set_param(MsCtxParam param, const T &value) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
const T &get_param(MsCtxParam param) const {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
void increase_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
void decrease_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
private: private:
inline static DeviceSeter seter_ = nullptr; inline static DeviceSeter seter_ = nullptr;
inline static DeviceTypeSeter device_type_seter_ = nullptr; inline static DeviceTypeSeter device_type_seter_ = nullptr;
static std::shared_ptr<MsContext> inst_context_; static std::shared_ptr<MsContext> inst_context_;
static std::map<std::string, MsBackendPolicy> policy_map_; static std::map<std::string, MsBackendPolicy> policy_map_;
bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS];
int int_params_[MsCtxParam::NUM_INT_PARAMS];
uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
MsBackendPolicy backend_policy_; MsBackendPolicy backend_policy_;
std::string device_target_;
uint32_t device_id_;
uint32_t max_call_depth_;
int execution_mode_;
bool enable_pynative_infer_;
bool enable_pynative_hook_;
bool save_graphs_flag_;
std::string save_graphs_path_;
uint32_t tsd_ref_;
uint32_t ge_ref_;
bool enable_task_sink_;
bool enable_hccl_;
bool precompile_only_;
bool ir_fusion_flag_;
bool auto_mixed_precision_flag_;
bool enable_reduce_precision_;
bool enable_loop_sink_;
bool enable_mem_reuse_;
bool enable_gpu_summary_;
bool enable_dump_;
std::string save_dump_path_;
bool is_multi_graph_sink_;
bool is_pynative_ge_init_;
bool enable_dynamic_mem_pool_;
std::string graph_memory_max_size_;
std::string variable_memory_max_size_;
bool profiling_mode_;
std::string profiling_options_;
bool check_bprop_flag_;
float max_device_memory_;
std::string print_file_path_;
bool enable_graph_kernel_;
bool enable_sparse_;
}; };
// set method implementation for type bool/int/uint32_t/float/std::string
template <>
inline void MsContext::set_param<bool>(MsCtxParam param, const bool &value) {
bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value;
}
template <>
inline void MsContext::set_param<int>(MsCtxParam param, const int &value) {
int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value;
}
template <>
inline void MsContext::set_param<uint32_t>(MsCtxParam param, const uint32_t &value) {
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value;
}
template <>
inline void MsContext::set_param<float>(MsCtxParam param, const float &value) {
float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value;
}
template <>
inline void MsContext::set_param<std::string>(MsCtxParam param, const std::string &value) {
if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) {
MS_LOG(INFO) << "ms set context device target:" << value;
seter_(value);
}
string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value;
}
// get method implementation for type bool/int/uint32_t/float/std::string
template <>
inline const bool &MsContext::get_param<bool>(MsCtxParam param) const {
return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN];
}
template <>
inline const int &MsContext::get_param<int>(MsCtxParam param) const {
return int_params_[param - MS_CTX_TYPE_INT_BEGIN];
}
template <>
inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const {
return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN];
}
template <>
inline const float &MsContext::get_param<float>(MsCtxParam param) const {
return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN];
}
template <>
inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const {
return string_params_[param - MS_CTX_TYPE_STRING_BEGIN];
}
// increate method implementation for type uint32_t
template <>
inline void MsContext::increase_param<uint32_t>(MsCtxParam param) {
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++;
}
// decreate method implementation for type uint32_t
template <>
inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) {
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--;
}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ #endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_

View File

@ -42,7 +42,7 @@ class TestOptLib : public UT::Common {
parse::data_converter::ClearObjectCache(); parse::data_converter::ClearObjectCache();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
} }
FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) {
equiv_node.clear(); equiv_node.clear();

View File

@ -112,7 +112,7 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
*/ */
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0"); auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
// Do insert_trans_op_ pass of hardware opt // Do insert_trans_op_ pass of hardware opt
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>(); auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();

View File

@ -112,7 +112,7 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) { TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before"); auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
// insert trans op for output // insert trans op for output
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>(); auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();

View File

@ -104,7 +104,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
*/ */
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before"); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
std::vector<int> shp{2, 4, 8, 16}; std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);

View File

@ -83,7 +83,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
*/ */
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before"); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
std::vector<int> shp{2, 4, 8, 16}; std::vector<int> shp{2, 4, 8, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);

View File

@ -76,7 +76,7 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
*/ */
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode); ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before"); FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
// Renormalize func_graph to infer and set shape and type information. // Renormalize func_graph to infer and set shape and type information.
std::vector<int> shp{2, 32, 224, 224}; std::vector<int> shp{2, 32, 224, 224};

View File

@ -71,15 +71,15 @@ OpExecInfoPtr ConstructOpExecInfo() {
TEST_F(TestPynativeExecute, TestCreateContext) { TEST_F(TestPynativeExecute, TestCreateContext) {
auto ctx3 = MsContext::GetInstance(); auto ctx3 = MsContext::GetInstance();
ASSERT_EQ(ctx3->backend_policy(), "vm"); ASSERT_EQ(ctx3->backend_policy(), "vm");
ASSERT_EQ(ctx3->device_target(), "CPU"); ASSERT_EQ(ctx3->get_param<std::string>(MS_CTX_DEVICE_TARGET), "CPU");
ctx3->set_backend_policy("ge_only"); ctx3->set_backend_policy("ge_only");
ctx3->set_device_target("GPU"); ctx3->set_param<std::string>(MS_CTX_DEVICE_TARGET, "GPU");
auto ctx4 = MsContext::GetInstance(); auto ctx4 = MsContext::GetInstance();
ASSERT_EQ(ctx3.get(), ctx4.get()); ASSERT_EQ(ctx3.get(), ctx4.get());
ASSERT_EQ(ctx4->backend_policy(), "ge_only"); ASSERT_EQ(ctx4->backend_policy(), "ge_only");
ASSERT_EQ(ctx4->device_target(), "GPU"); ASSERT_EQ(ctx4->get_param<std::string>(MS_CTX_DEVICE_TARGET), "GPU");
} }
TEST_F(TestPynativeExecute, TestDefaultContext) { TEST_F(TestPynativeExecute, TestDefaultContext) {