forked from OSSInnovation/mindspore
!5352 refactor ms_context implementation
Merge pull request !5352 from fary86/refactor_context_interface
This commit is contained in:
commit
8d41931456
|
@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_task_sink()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
return true;
|
||||
}
|
||||
if (inputs.empty() || hccl_data_type_list_.empty()) {
|
||||
|
|
|
@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_task_sink()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
return true;
|
||||
}
|
||||
if (inputs.empty() || hccl_data_type_list_.empty()) {
|
||||
|
|
|
@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_task_sink()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
return true;
|
||||
}
|
||||
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
|
||||
|
|
|
@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_task_sink()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
return true;
|
||||
}
|
||||
if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
|
||||
|
|
|
@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL) &&
|
||||
IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
|
||||
kernel_type = KernelType::AKG_KERNEL;
|
||||
}
|
||||
|
||||
|
|
|
@ -328,7 +328,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
|
|||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_gpu = (context->device_target() == kGPUDevice);
|
||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||
if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
|
||||
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
|
||||
<< ", current op num: " << op_info_.size();
|
||||
|
|
|
@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
}
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
||||
if (context_ptr->execution_mode() == kPynativeMode) {
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
|
||||
} else {
|
||||
|
@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
|
||||
AddAscendIRFusionPass(ir_fusion_pm.get());
|
||||
|
||||
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
|
||||
ConfigManager::GetInstance().iter_num() > 1) {
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
|
@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->ir_fusion_flag()) {
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
|
||||
MS_LOG(INFO) << "IRFusion is not enable, skip";
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
auto other2_pm = std::make_shared<PassManager>("other2_pm");
|
||||
other2_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
other2_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) &&
|
||||
ConfigManager::GetInstance().iter_num() > 1) {
|
||||
other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
|
||||
}
|
||||
other2_pm->AddPass(std::make_shared<CheckConsistency>());
|
||||
|
@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
|
|||
bool is_before_kernel_select) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
|
|||
bool is_before_kernel_select) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
|
|||
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &ke
|
|||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->ir_fusion_flag()) {
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_IR_FUSION_FLAG)) {
|
||||
MS_LOG(INFO) << "UBFusion is not enable, skip";
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
|
|
@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
|
||||
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
|
||||
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return RectifyKernelInfoInPynativeProcess(node);
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
|
||||
|
|
|
@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
|
|
@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
|
|||
bool IsNopNode(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) {
|
||||
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
|
||||
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||
return false;
|
||||
}
|
||||
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
|
||||
|
|
|
@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
|
|||
}
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
|
|
@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
||||
if (ms_context->device_target() == kAscendDevice) {
|
||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
if (!CheckAttrs(strided_slice_grad)) {
|
||||
MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed";
|
||||
return nullptr;
|
||||
|
|
|
@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
if (context_ptr->save_graphs_flag()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir";
|
||||
DumpIR(file_path, root_graph.get());
|
||||
}
|
||||
|
|
|
@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|||
debugger_->PreExecute(graph);
|
||||
}
|
||||
#endif
|
||||
if (ms_context->precompile_only()) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
|
||||
MS_LOG(INFO) << "Precompile only, stop in build kernel step";
|
||||
} else {
|
||||
// alloc memory, including static memory and dynamic memory
|
||||
|
@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
|
|||
child_graph->SetExecOrderByDefault();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
|
|||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kGraphMode) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
if (raise_precision_count > 0) {
|
||||
MS_LOG(WARNING) << "There has " << raise_precision_count
|
||||
<< " node/nodes used raise precision to selected the kernel!";
|
||||
|
@ -481,8 +481,8 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs)
|
|||
#ifdef ENABLE_DUMP_IR
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
if (!save_graphs) {
|
||||
return;
|
||||
}
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() {
|
|||
if (graph_order.size() > 1) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->enable_task_sink()) {
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
|
||||
}
|
||||
}
|
||||
|
@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs) {
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
|
@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
|
|||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kGraphMode) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
if (raise_precision_count > 0) {
|
||||
MS_LOG(WARNING) << "There are " << raise_precision_count
|
||||
<< " node/nodes used raise precision to selected the kernel!";
|
||||
|
@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs) {
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
|
|
|
@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) {
|
||||
if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
|
@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
bool need_sync = false;
|
||||
if (ms_context->enable_pynative_infer()) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
||||
if (tensor_address == nullptr || tensor_address != device_address) {
|
||||
need_sync = true;
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
|
|||
// Prepare ms context info for dump .pb graph
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
// Optimize
|
||||
Optimize(graph);
|
||||
// Select kernel build info
|
||||
|
@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
|
|||
// Summary
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_gpu_summary()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) {
|
||||
Summary(kernel_graph.get());
|
||||
}
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
|
|
|
@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() {
|
|||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
MsContext::GetInstance()->set_execution_mode(kGraphMode);
|
||||
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
Py_Initialize();
|
||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||
if (c_expression == nullptr) {
|
||||
|
@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
|||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_device_id(device_id);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
|
||||
auto ajust_device = AjustTargetName(device);
|
||||
if (ajust_device == "") {
|
||||
return FAILED;
|
||||
}
|
||||
ms_context->set_device_target(device);
|
||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device);
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return FAILED;
|
||||
|
|
|
@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
|
|||
// if in paynative mode,data only copyed to host when user want to print data
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() != kPynativeMode && ms_context->device_target() != kGPUDevice) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
|
||||
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||
tensor->set_need_sync(true);
|
||||
}
|
||||
if (ms_context->execution_mode() != kPynativeMode) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
tensor->SetNeedWait(true);
|
||||
}
|
||||
tensor->set_dirty(false);
|
||||
|
@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
|
||||
if (ms_context->enable_pynative_infer()) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
||||
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
|
||||
}
|
||||
if (tensor->is_dirty()) {
|
||||
|
@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
||||
if (ms_context->execution_mode() == kPynativeMode ||
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
|
||||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
|
||||
tensor->set_device_address(device_address);
|
||||
}
|
||||
|
@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
|
|||
if (backend_anf != nullptr) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->execution_mode() == kPynativeMode) {
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return backend_anf;
|
||||
}
|
||||
|
||||
|
|
|
@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
debugger_ = Debugger::GetInstance();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
debugger_->Init(device_id_, ms_context->device_target());
|
||||
debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const {
|
|||
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (context->execution_mode() == kPynativeMode) {
|
||||
if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump";
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -142,7 +142,7 @@ void Debugger::EnableDebugger() {
|
|||
// switch memory reuse on or off
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
context_ptr->set_enable_mem_reuse(partial_memory_);
|
||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, partial_memory_);
|
||||
// print some message about memory reuse to user
|
||||
if (partial_memory_) {
|
||||
MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
|
||||
|
|
|
@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
|
|||
MS_LOG(ERROR) << "ms_context is nullptr";
|
||||
return;
|
||||
}
|
||||
auto save_graphs_path = ms_context->save_graphs_path();
|
||||
auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
|
|||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
// dump_enable_ is true, close mem reuse
|
||||
context_ptr->set_enable_mem_reuse(!dump_enable_);
|
||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, !dump_enable_);
|
||||
trans_flag_ = trans_flag;
|
||||
dump_mode_ = mode;
|
||||
dump_path_ = path;
|
||||
|
@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() {
|
|||
}
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto id = context_ptr->device_id();
|
||||
auto id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
char real_path[PATH_MAX] = {0};
|
||||
if (nullptr == realpath(config_path_str, real_path)) {
|
||||
MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str;
|
||||
|
|
|
@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
|||
manager_ptr->AddFuncGraph(func_graph);
|
||||
|
||||
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
|
||||
if (MsContext::GetInstance()->is_multi_graph_sink()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
|
||||
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
|
|
|
@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
|
|||
void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool check_bprop_flag = context->check_bprop_flag();
|
||||
bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
|
||||
// Skip checking if check_bprop not set
|
||||
if (!check_bprop_flag) {
|
||||
return;
|
||||
|
|
|
@ -29,7 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
PConstant const_2(node);
|
||||
PConstant any_const(node);
|
||||
|
||||
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
MATCH_REPLACE(node, x + zero_, x); // Add by zero
|
||||
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
|
||||
|
@ -41,7 +41,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
}
|
||||
// Prim Eliminate (identity)
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -75,7 +75,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
}
|
||||
|
||||
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode x, y;
|
||||
|
|
|
@ -181,7 +181,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
}
|
||||
};
|
||||
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
|
||||
if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) {
|
||||
if (is_on_debug_ && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
|
||||
auto fg_name =
|
||||
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
|
||||
|
|
|
@ -217,8 +217,8 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
|
|||
|
||||
void DrawNode(string name, AnfNodePtr node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
bool save_graphs = context_ptr->save_graphs_flag();
|
||||
auto save_graphs_path = context_ptr->save_graphs_path();
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::str
|
|||
}
|
||||
|
||||
void DumpGraph(const FuncGraphPtr &root, const std::string &name) {
|
||||
if (MsContext::GetInstance()->save_graphs_flag()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
draw::Draw(name + ".dot", root);
|
||||
DumpIR(name + ".ir", root);
|
||||
ExportIR(name + ".dat", "0", root);
|
||||
|
|
|
@ -69,7 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
|||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
|
||||
if (MsContext::GetInstance()->save_graphs_flag()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
|
||||
}
|
||||
MS_LOG(INFO) << "Now entering step auto parallel";
|
||||
|
|
|
@ -271,7 +271,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
|
|||
if (!result) {
|
||||
MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
|
||||
}
|
||||
if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) {
|
||||
auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
|
||||
auto func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -295,20 +295,20 @@ bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res,
|
|||
|
||||
static bool IsCtrlSink() {
|
||||
auto ms_ctx = MsContext::GetInstance();
|
||||
if (ms_ctx->execution_mode() != kGraphMode) {
|
||||
if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string device_target = ms_ctx->device_target();
|
||||
std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target != kAscendDevice) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ms_ctx->enable_task_sink()) {
|
||||
if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ms_ctx->is_multi_graph_sink()) {
|
||||
if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -325,13 +325,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (CompileGraphs::ContainMixedTarget(func_graph)) {
|
||||
bc_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_loop_sink_flag(false);
|
||||
} else if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
std::string device_target = context_ptr->device_target();
|
||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
|
||||
} else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target == kAscendDevice && backend != kMsVm) {
|
||||
bc_ptr->set_is_multi_graph_sink(true);
|
||||
context_ptr->set_is_multi_graph_sink(true);
|
||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ inline std::string GetFilePathName(const std::string &file_name) {
|
|||
if (ms_context == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ms_context is nullptr";
|
||||
}
|
||||
auto save_graphs_path = ms_context->save_graphs_path();
|
||||
auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
|
|
@ -48,6 +48,56 @@ using OpLib = mindspore::kernel::OpLib;
|
|||
using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
|
||||
using ParallelContext = mindspore::parallel::ParallelContext;
|
||||
using CostModelContext = mindspore::parallel::CostModelContext;
|
||||
using mindspore::MsCtxParam;
|
||||
|
||||
namespace mindspore {
|
||||
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
|
||||
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '"
|
||||
<< py::str(value.get_type()) << "'.";
|
||||
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
|
||||
ctx->set_param<bool>(param, value.cast<bool>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
|
||||
ctx->set_param<int>(param, value.cast<int>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
|
||||
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
|
||||
ctx->set_param<float>(param, value.cast<float>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
|
||||
ctx->set_param<std::string>(param, value.cast<std::string>());
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type());
|
||||
}
|
||||
|
||||
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
|
||||
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
|
||||
return py::bool_(ctx->get_param<bool>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
|
||||
return py::int_(ctx->get_param<int>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
|
||||
return py::int_(ctx->get_param<uint32_t>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
|
||||
return py::float_(ctx->get_param<float>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
|
||||
return py::str(ctx->get_param<std::string>(param));
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
// Interface with python
|
||||
PYBIND11_MODULE(_c_expression, m) {
|
||||
|
@ -101,53 +151,48 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
|
||||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
|
||||
|
||||
(void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.");
|
||||
(void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.");
|
||||
|
||||
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
|
||||
.value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG)
|
||||
.value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
|
||||
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
|
||||
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
|
||||
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
|
||||
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
|
||||
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
|
||||
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
|
||||
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
|
||||
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
|
||||
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
|
||||
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
|
||||
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
|
||||
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
|
||||
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
|
||||
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
|
||||
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
|
||||
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
|
||||
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
|
||||
.value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
|
||||
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
|
||||
.value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
|
||||
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
|
||||
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
|
||||
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
|
||||
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
|
||||
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
|
||||
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
|
||||
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
|
||||
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
|
||||
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
|
||||
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
|
||||
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
|
||||
|
||||
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
|
||||
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
|
||||
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
|
||||
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.")
|
||||
.def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.")
|
||||
.def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.")
|
||||
.def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.")
|
||||
.def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.")
|
||||
.def("get_device_target", &mindspore::MsContext::device_target, "Get device target.")
|
||||
.def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
|
||||
.def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
|
||||
.def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
|
||||
.def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.")
|
||||
.def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.")
|
||||
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
|
||||
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
|
||||
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
|
||||
"Get whether to enable auto mixed precision.")
|
||||
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
|
||||
"Set whether to enable auto mixed precision.")
|
||||
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
|
||||
"Get whether to enable reduce precision.")
|
||||
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
|
||||
"Set whether to enable reduce precision.")
|
||||
.def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
|
||||
.def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
|
||||
.def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.")
|
||||
.def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.")
|
||||
.def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.")
|
||||
.def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.")
|
||||
.def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.")
|
||||
.def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size,
|
||||
"set variable memory max size")
|
||||
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
|
||||
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
|
||||
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
|
||||
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
|
||||
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
|
||||
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
|
||||
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
|
||||
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.")
|
||||
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
|
||||
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
|
||||
"Set the GraphKernel switch to on or off.")
|
||||
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
|
||||
.def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
|
||||
.def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
|
||||
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -271,7 +271,7 @@ void InitOpt(const ResourcePtr &res) {
|
|||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->enable_graph_kernel())) {
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
|
||||
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) {
|
|||
if (ms_context == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ms_context is nullptr";
|
||||
}
|
||||
auto save_graphs_path = ms_context->save_graphs_path();
|
||||
auto save_graphs_path = ms_context->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
|
||||
if (save_graphs_path.empty()) {
|
||||
save_graphs_path = ".";
|
||||
}
|
||||
|
@ -646,7 +646,7 @@ void Pipeline::Run() {
|
|||
if (!result) {
|
||||
MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
|
||||
}
|
||||
if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && resource_->func_graph() != nullptr) {
|
||||
auto graph = resource_->func_graph();
|
||||
if (graph != nullptr) {
|
||||
user_graph = graph;
|
||||
|
@ -688,7 +688,7 @@ void Pipeline::Run() {
|
|||
MsProfile::Reset();
|
||||
#endif
|
||||
|
||||
if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && (user_graph != nullptr)) {
|
||||
std::string user_graph_file = GetFilePathName("ModelDigraph.dot");
|
||||
MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file;
|
||||
draw::DrawUserFuncGraph(user_graph_file, user_graph);
|
||||
|
@ -710,7 +710,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
|
|||
if (!succ) {
|
||||
MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
|
||||
}
|
||||
if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa<tensor::Tensor>()) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == 0 && !converted->isa<tensor::Tensor>()) {
|
||||
MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString()
|
||||
<< " is not tensor.";
|
||||
}
|
||||
|
@ -891,7 +891,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
// Convert CNodeList to LinConvertResult.
|
||||
ConfigManager::GetInstance().set_iter_num(1);
|
||||
auto runner = convert_fn({app_init}, "");
|
||||
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
backend->Link(runner.graph_id);
|
||||
}
|
||||
ConfigManager::GetInstance().set_iter_num(size);
|
||||
|
@ -965,10 +965,11 @@ void InitHccl() {
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
(void)context::OpenTsd(ms_context);
|
||||
uint32_t device_id = ms_context->device_id();
|
||||
std::string device_name = ms_context->device_target();
|
||||
ms_context->set_enable_hccl(true);
|
||||
if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) {
|
||||
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
|
||||
if (ms_context->backend_policy() == "ms" &&
|
||||
ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
if (!runtime_instance->Init()) {
|
||||
|
|
|
@ -214,7 +214,7 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
|
|||
return false;
|
||||
}
|
||||
|
||||
if (MsContext::GetInstance()->save_graphs_flag()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug
|
||||
convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug
|
||||
convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug
|
||||
|
@ -244,7 +244,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
|
|||
}
|
||||
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
|
||||
|
||||
if (MsContext::GetInstance()->save_graphs_flag()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug
|
||||
DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true);
|
||||
}
|
||||
|
|
|
@ -118,8 +118,9 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
<< ", current function call depth: " << engine->function_call_depth();
|
||||
AbstractBasePtr ret_base = nullptr;
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << ".";
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << ".";
|
||||
}
|
||||
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
|
||||
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
|
||||
|
@ -409,7 +410,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|||
bparams.push_back(SensitivityTransform(orig_func_));
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
|
||||
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
|
||||
if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
|
||||
|
|
|
@ -62,7 +62,7 @@ class Evaluator : public Base {
|
|||
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
|
||||
if (!enable_sparse) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -290,7 +290,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
if (abs_base->isa<AbstractTensor>()) {
|
||||
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
||||
dic["shape"] = arg_tensor->shape()->shape();
|
||||
if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
const auto &min_shape = arg_tensor->shape()->min_shape();
|
||||
const auto &max_shape = arg_tensor->shape()->max_shape();
|
||||
if (!min_shape.empty() && !max_shape.empty()) {
|
||||
|
|
|
@ -558,8 +558,8 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
ms_context->set_enable_pynative_infer(true);
|
||||
std::string device_target = ms_context->device_target();
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
|
||||
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target != kAscendDevice && device_target != kGPUDevice) {
|
||||
MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
|
||||
}
|
||||
|
@ -567,7 +567,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
if (session == nullptr) {
|
||||
session = session::SessionFactory::Get().Create(device_target);
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
session->Init(ms_context->device_id());
|
||||
session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> input_tensors;
|
||||
|
@ -578,7 +578,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, &input_tensors);
|
||||
py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
|
||||
ms_context->set_enable_pynative_infer(false);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
*status = PYNATIVE_SUCCESS;
|
||||
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
|
||||
return result;
|
||||
|
@ -1308,7 +1308,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|||
// Maybe exit in the pynative runing op, so need reset pynative flag.
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context != nullptr) {
|
||||
ms_context->set_enable_pynative_infer(false);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
}
|
||||
ConfigManager::GetInstance().ResetIterNum();
|
||||
return;
|
||||
|
|
|
@ -89,7 +89,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
|
|||
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradients number: " << grads.size()
|
||||
<< " is not equal to the args number: " << py_args.size() - 2 << ".";
|
||||
}
|
||||
if (MsContext::GetInstance()->check_bprop_flag()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG)) {
|
||||
for (size_t i = 0; i < grads.size(); i++) {
|
||||
if (py::isinstance<tensor::Tensor>(py_args[i])) {
|
||||
if (!py::isinstance<tensor::Tensor>(grads[i])) {
|
||||
|
|
|
@ -154,7 +154,7 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
|
|||
DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->device_id();
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type);
|
||||
|
@ -261,11 +261,12 @@ void AscendDeviceAddress::SyncStream() const {
|
|||
MS_LOG(INFO) << "Start!";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() != kPynativeMode && !ms_context->enable_pynative_infer()) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
|
||||
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
return;
|
||||
}
|
||||
auto device_id = ms_context->device_id();
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto ret = runtime_instance->SyncStream();
|
||||
|
@ -348,7 +349,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
|
|||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_id = ms_context->device_id();
|
||||
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
auto ret =
|
||||
|
@ -475,7 +476,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
|
|||
std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() != kGraphMode && ms_context->execution_mode() != kPynativeMode &&
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode &&
|
||||
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
|
||||
type_id_name_map.find(type_id_) != type_id_name_map.end()) {
|
||||
std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
|
||||
if (use_trans_data.find(type_format) != use_trans_data.end()) {
|
||||
|
|
|
@ -158,7 +158,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
|
|||
bool AscendKernelRuntime::NeedDestroyHccl() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->enable_hccl()) {
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
MS_LOG(INFO) << "Hccl is not enabled";
|
||||
return false;
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto ret = rtSetDevice(context_ptr->device_id());
|
||||
auto ret = rtSetDevice(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
|
@ -461,12 +461,12 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_task_sink = context_ptr->enable_task_sink();
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (!is_task_sink) {
|
||||
return true;
|
||||
}
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
if (!context_ptr->enable_mem_reuse()) {
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE)) {
|
||||
// Get normal graph ir for memreuse
|
||||
mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
|
||||
}
|
||||
|
@ -518,7 +518,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_task_sink = context_ptr->enable_task_sink();
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (!is_task_sink) {
|
||||
return true;
|
||||
}
|
||||
|
@ -658,7 +658,7 @@ bool AscendKernelRuntime::InitDevice() {
|
|||
MS_LOG(ERROR) << "Get MsContext instance failed";
|
||||
return false;
|
||||
}
|
||||
if (context_ptr->enable_hccl()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
if (!HcclInit()) {
|
||||
MS_LOG(ERROR) << "HcclInit init failed";
|
||||
return false;
|
||||
|
@ -746,7 +746,7 @@ bool AscendKernelRuntime::DestroyHccl() {
|
|||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Hccl destroy successful, status = " << res << ".";
|
||||
context_ptr->set_enable_hccl(false);
|
||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ void AscendMemoryManager::MallocDeviceMemory() {
|
|||
uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto variable_memory_max_size = context->variable_memory_max_size();
|
||||
auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||
if (variable_memory_max_size == "0") {
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -1373,7 +1373,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
|
|||
bool AscendStreamAssign::IsTaskSink() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (!ms_context->enable_task_sink()) {
|
||||
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
|
||||
MS_LOG(INFO) << "Task sink mode is not enable";
|
||||
return false;
|
||||
} else {
|
||||
|
|
|
@ -117,7 +117,7 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
|
|||
if (!dump_path.has_value()) {
|
||||
MS_LOG(EXCEPTION) << "Dump path invalid";
|
||||
}
|
||||
auto device_id = context_ptr->device_id();
|
||||
auto device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
dump_info->set_dump_path("/" + dump_path.value() + "_" + std::to_string(device_id) + "/");
|
||||
MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value();
|
||||
|
||||
|
|
|
@ -363,7 +363,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
|
|||
*precision_reduce = false;
|
||||
return;
|
||||
}
|
||||
if (context_ptr->enable_reduce_precision()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
|
||||
selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
|
||||
kernel_support_datatype, &kernel_match_datatype_idx_copy);
|
||||
}
|
||||
|
|
|
@ -117,7 +117,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
|
|||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
const string prof_options_str = context->profiling_options();
|
||||
const string prof_options_str = context->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
|
||||
std::vector<string> opts = Split(prof_options_str, ':');
|
||||
if (opts.empty()) {
|
||||
MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!";
|
||||
|
|
|
@ -41,7 +41,7 @@ class ProfilingManager {
|
|||
inline bool IsProfiling() const {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return context->enable_profiling();
|
||||
return context->get_param<bool>(MS_CTX_ENABLE_PROFILING);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
@ -342,12 +342,12 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
|
|||
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second);
|
||||
TaskDescReporter task_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.task_desc_info", ret->second);
|
||||
task_reporter.set_task_ids(task_ids);
|
||||
task_reporter.set_stream_ids(stream_ids);
|
||||
task_reporter.ReportData();
|
||||
|
||||
GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
|
||||
GraphDescReporter graph_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.graph_desc_info", ret->second);
|
||||
graph_profiling_cnode_.erase(ret);
|
||||
graph_reporter.ReportData();
|
||||
|
||||
|
@ -357,7 +357,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
|
|||
MS_LOG(ERROR) << "Graph id not found in graph_point";
|
||||
return;
|
||||
}
|
||||
PointReporter point_reporter(context->device_id(), "vm.point");
|
||||
PointReporter point_reporter(context->get_param<uint32_t>(MS_CTX_DEVICE_ID), "vm.point");
|
||||
for (const auto &point : point_iter->second) {
|
||||
point_reporter.AddReportData(point);
|
||||
}
|
||||
|
|
|
@ -416,7 +416,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|||
mem_manager_->ResetDynamicMemory();
|
||||
AssignStaticMemoryInput(graph);
|
||||
AssignStaticMemoryValueNode(graph);
|
||||
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
|
||||
bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
|
||||
if (is_enable_dynamic_mem) {
|
||||
// Use the dynamic memory pool.
|
||||
InitKernelRefCount(graph);
|
||||
|
@ -435,8 +435,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
|
|||
bool ret = true;
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
|
||||
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
|
||||
bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
|
||||
bool is_enable_pynative_infer = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
|
||||
if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
|
||||
auto graph_id = graph->graph_id();
|
||||
auto iter = mem_swap_map_.find(graph_id);
|
||||
|
|
|
@ -29,7 +29,7 @@ bool GPUMemoryAllocator::Init() {
|
|||
size_t free_size = CudaDriver::free_mem_size();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
limited_device_memory_ = context_ptr->max_device_memory();
|
||||
limited_device_memory_ = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
|
||||
available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024);
|
||||
if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) {
|
||||
MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size
|
||||
|
@ -44,7 +44,7 @@ bool GPUMemoryAllocator::Init() {
|
|||
void GPUMemoryAllocator::CheckMaxDeviceMemory() const {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto max_device_memory = context_ptr->max_device_memory();
|
||||
auto max_device_memory = context_ptr->get_param<float>(MS_CTX_MAX_DEVICE_MEMORY);
|
||||
// Currently not support modifying the max device memory.
|
||||
if (limited_device_memory_ != max_device_memory) {
|
||||
MS_LOG(EXCEPTION)
|
||||
|
|
|
@ -37,7 +37,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
|
|||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
// If use the dynamic memory pool, then alloc the first memory block to init.
|
||||
if (context_ptr->enable_dynamic_mem_pool()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
|
||||
auto device_addr = MallocMemFromMemPool(1);
|
||||
if (!device_addr) {
|
||||
MS_LOG(EXCEPTION) << "Dynamic memory pool init error.";
|
||||
|
@ -65,7 +65,7 @@ void GPUMemoryManager::FreeDeviceMemory() {
|
|||
uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->enable_dynamic_mem_pool()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) {
|
||||
auto device_ptr = MallocMemFromMemPool(size);
|
||||
MS_EXCEPTION_IF_NULL(device_ptr);
|
||||
return AddressOffset(device_ptr, 0);
|
||||
|
|
|
@ -162,7 +162,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|||
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->execution_mode() == kPynativeMode) {
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return false;
|
||||
}
|
||||
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
|
||||
|
|
|
@ -60,8 +60,8 @@ void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &k
|
|||
bool KernelAdjust::NeedInsertSwitch() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
|
||||
ConfigManager::GetInstance().iter_num() > 1);
|
||||
return (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) &&
|
||||
context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1);
|
||||
}
|
||||
|
||||
CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
|
||||
|
|
|
@ -50,7 +50,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
|
|||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
bool is_task_sink = context_ptr->enable_task_sink();
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (is_task_sink) {
|
||||
ret = RunTask(graph);
|
||||
} else {
|
||||
|
@ -502,7 +502,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|||
MS_LOG(INFO) << "communication op addr exist";
|
||||
continue;
|
||||
}
|
||||
if (context_ptr->enable_hccl()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
||||
}
|
||||
total_size += mem_size;
|
||||
|
@ -646,7 +646,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
|
|||
DeviceAddressPtr address = nullptr;
|
||||
address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, node_size)) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
|
||||
!mem_manager_->MallocMemFromMemPool(address, node_size)) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << node_size;
|
||||
} else if (mem_manager_->MallocMem(kStaticMem, node_size, address) == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
|
||||
|
@ -682,7 +683,8 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
|
|||
DeviceAddressPtr address = nullptr;
|
||||
address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
if (ms_context->enable_pynative_infer() && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
|
||||
!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc address from memory pool when tensor size is: " << tensor_size;
|
||||
} else if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
||||
|
@ -701,7 +703,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
|
|||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
|
||||
bool is_enable_mem_reuse = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE);
|
||||
auto mem_type = kDynamicMem;
|
||||
if (is_enable_mem_reuse) {
|
||||
mem_manager_->MallocReusedDynamicMem(graph);
|
||||
|
|
|
@ -54,7 +54,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me
|
|||
uint8_t *ptr = nullptr;
|
||||
if (AnfAlgo::IsCommunicationOp(node)) {
|
||||
bool communication_mem = false;
|
||||
if (context_ptr->enable_hccl()) {
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
communication_mem = true;
|
||||
}
|
||||
if (type == kStaticMem) {
|
||||
|
|
|
@ -1070,7 +1070,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
|
|||
convertor.inputs_ = inputs;
|
||||
(void)convertor.ConvertAllNode().BuildGraph();
|
||||
std::string name = graph_node->ToString() + "_ge_graph.dot";
|
||||
if (MsContext::GetInstance()->save_graphs_flag()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
convertor.DrawComputeGraph(name);
|
||||
}
|
||||
branches_map_[node.get()] = *(convertor.df_graph_);
|
||||
|
|
|
@ -41,13 +41,13 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
||||
if (ms_context_ptr->is_pynative_ge_init()) {
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ms_context_ptr->tsd_ref()) {
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
|
||||
MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
|
||||
ms_context_ptr->set_tsd_ref("++");
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -59,7 +59,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
unsigned int device_id;
|
||||
unsigned int rank_size = 1;
|
||||
|
||||
device_id = ms_context_ptr->device_id();
|
||||
device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
|
||||
auto rank_size_env = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size_env.empty()) {
|
||||
|
@ -79,7 +79,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
|
||||
return false;
|
||||
}
|
||||
ms_context_ptr->set_tsd_ref("++");
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||
if (initStatus != TDT_OK_CODE) {
|
||||
|
@ -88,7 +88,8 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
}
|
||||
ms_context_ptr->tdt_print_ = std::thread(TensorPrint());
|
||||
#endif
|
||||
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << ms_context_ptr->tsd_ref() << ".";
|
||||
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -96,12 +97,12 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
if (ms_context_ptr->tsd_ref() == 0) {
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
return true;
|
||||
}
|
||||
ms_context_ptr->set_tsd_ref("--");
|
||||
if (force || ms_context_ptr->tsd_ref() == 0) {
|
||||
ms_context_ptr->set_tsd_ref(" ");
|
||||
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
|
||||
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
|
||||
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
|
||||
#ifdef ENABLE_TDTQUE
|
||||
int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
|
||||
if (stopStatus != TDT_OK_CODE) {
|
||||
|
@ -123,17 +124,17 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
|
||||
}
|
||||
#endif
|
||||
auto device_id = ms_context_ptr->device_id();
|
||||
auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
TDT_StatusT status = TsdClose(device_id);
|
||||
if (status != TDT_OK) {
|
||||
MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << ".";
|
||||
return false;
|
||||
}
|
||||
ms_context_ptr->set_pynative_ge_init(false);
|
||||
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << ".";
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->tsd_ref()
|
||||
<< ".";
|
||||
MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -159,14 +160,14 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
|
|||
}
|
||||
#ifdef ENABLE_GE
|
||||
(*ge_options)["device_id"] = "0";
|
||||
(*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->enable_dump());
|
||||
(*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->save_dump_path();
|
||||
(*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
|
||||
(*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
|
||||
(*ge_options)["ge.exec.dumpMode"] = "output";
|
||||
MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->enable_dump())
|
||||
<< " and save dump path is " << ms_context_ptr->save_dump_path() << ".";
|
||||
(*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->enable_profiling());
|
||||
if (ms_context_ptr->enable_profiling()) {
|
||||
(*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->profiling_options();
|
||||
MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
|
||||
<< " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
|
||||
(*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING));
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)) {
|
||||
(*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
|
||||
}
|
||||
|
||||
(*ge_options)["rank_table_file"] = "";
|
||||
|
@ -178,12 +179,12 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
|
|||
}
|
||||
(*ge_options)["graphType"] = "1";
|
||||
|
||||
if (ms_context_ptr->graph_memory_max_size() != "0") {
|
||||
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->graph_memory_max_size();
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
|
||||
(*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
if (ms_context_ptr->variable_memory_max_size() != "0") {
|
||||
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->variable_memory_max_size();
|
||||
if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
|
||||
(*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
|
||||
}
|
||||
|
||||
#if ENABLE_TRAIN == 1
|
||||
|
@ -224,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
|
|||
}
|
||||
|
||||
// Enable auto mixed precision according to the context options
|
||||
if (ms_context_ptr->auto_mixed_precision_flag()) {
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
|
||||
} else {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
|
||||
|
@ -240,7 +241,7 @@ void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<s
|
|||
}
|
||||
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
auto env_device_id = std::to_string(ms_context_ptr->device_id());
|
||||
auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
|
||||
if (!(env_table_file.empty() || env_rank_id.empty())) {
|
||||
MS_LOG(INFO) << "Initialize Ge for distribute parameter";
|
||||
MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
|
||||
|
@ -275,12 +276,12 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
#ifdef ENABLE_GE
|
||||
if (ms_context_ptr->is_pynative_ge_init()) {
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ms_context_ptr->ge_ref()) {
|
||||
ms_context_ptr->set_ge_ref("++");
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -293,8 +294,8 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
MS_LOG(EXCEPTION) << "Initialize GE failed!";
|
||||
}
|
||||
}
|
||||
ms_context_ptr->set_ge_ref("++");
|
||||
MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->ge_ref() << ".";
|
||||
ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
|
||||
MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
@ -303,12 +304,13 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
if (ms_context_ptr->is_pynative_ge_init() || ms_context_ptr->ge_ref() || ms_context_ptr->tsd_ref()) {
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
|
||||
ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
|
||||
return true;
|
||||
}
|
||||
(void)OpenTsd(ms_context_ptr);
|
||||
(void)InitGe(ms_context_ptr);
|
||||
ms_context_ptr->set_pynative_ge_init(true);
|
||||
ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -317,12 +319,12 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
#ifdef ENABLE_GE
|
||||
if (ms_context_ptr->ge_ref() == 0) {
|
||||
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
|
||||
return true;
|
||||
}
|
||||
ms_context_ptr->set_ge_ref("--");
|
||||
if (force || ms_context_ptr->ge_ref() == 0) {
|
||||
ms_context_ptr->set_ge_ref(" ");
|
||||
ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
|
||||
if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
|
||||
ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
|
||||
try {
|
||||
DfGraphManager::GetInstance().DeleteGraphRunner();
|
||||
DfGraphManager::GetInstance().DeleteGeSession();
|
||||
|
@ -337,7 +339,8 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
}
|
||||
ms_context_ptr->set_pynative_ge_init(false);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->ge_ref() << ".";
|
||||
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
|
@ -347,14 +350,14 @@ bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
|||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return ms_context_ptr->IsTsdOpened();
|
||||
return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
|
||||
}
|
||||
|
||||
bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
|
||||
if (ms_context_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return ms_context_ptr->IsGeInited();
|
||||
return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
|
||||
}
|
||||
|
||||
// Register for device type.
|
||||
|
|
|
@ -353,7 +353,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
|
|||
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
|
||||
if (enable_sparse) {
|
||||
return std::make_shared<abstract::AbstractUndetermined>();
|
||||
}
|
||||
|
|
|
@ -273,7 +273,7 @@ void TensorPrint::operator()() {
|
|||
prntpb::Print print;
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
std::string print_file_path = ms_context->print_file_path();
|
||||
std::string print_file_path = ms_context->get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
|
||||
if (print_file_path == "") {
|
||||
while (true) {
|
||||
std::vector<tdt::DataItem> bundle;
|
||||
|
|
|
@ -59,7 +59,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
|
|||
graph_id = target_sess_->CompileGraphAsync(lst, outputs);
|
||||
}
|
||||
|
||||
if (MsContext::GetInstance()->precompile_only()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
|
||||
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
|
||||
return result;
|
||||
}
|
||||
|
@ -180,7 +180,7 @@ void MsBackend::CreateOtherSession(const std::string &target) {
|
|||
}
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
uint32_t device_id = context_ptr->device_id();
|
||||
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
other_sess_->Init(device_id);
|
||||
other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
||||
other_device_ = target;
|
||||
|
|
|
@ -56,7 +56,7 @@ namespace {
|
|||
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string last_target = context_ptr->device_target();
|
||||
std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
for (auto &node : nodes) {
|
||||
if (node->isa<CNode>()) {
|
||||
std::string cur_target = GetCNodeTarget(node);
|
||||
|
@ -348,7 +348,7 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|||
if (prim->name() == prim::kPrimBpropCut->name()) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_enable_pynative_hook(true);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
|
||||
}
|
||||
|
||||
if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||
|
@ -412,7 +412,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
|
|||
if (ContainMultiTarget(nodes)) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
nodes = SplitSort(graph, default_target);
|
||||
return SplitNodesWithTarget(nodes, graph);
|
||||
}
|
||||
|
@ -920,17 +920,17 @@ BackendPtr CreateBackend() {
|
|||
}
|
||||
|
||||
if (name == kMsConvert) {
|
||||
std::string target = context_ptr->device_target();
|
||||
uint32_t device_id = context_ptr->device_id();
|
||||
std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
auto backend = std::make_shared<MsBackend>(name, target, device_id);
|
||||
std::string device_target = MsContext::GetInstance()->device_target();
|
||||
std::string device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target == kAscendDevice) {
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
||||
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
backend->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
||||
} else {
|
||||
backend->set_is_multi_graph_sink(true);
|
||||
context_ptr->set_is_multi_graph_sink(true);
|
||||
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
|
||||
}
|
||||
}
|
||||
return backend;
|
||||
|
|
|
@ -22,7 +22,7 @@ import threading
|
|||
from collections import namedtuple
|
||||
from types import FunctionType
|
||||
from mindspore import log as logger
|
||||
from mindspore._c_expression import MSContext
|
||||
from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param
|
||||
from mindspore._checkparam import args_type_check
|
||||
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
||||
_reset_auto_parallel_context
|
||||
|
@ -157,9 +157,15 @@ class _Context:
|
|||
raise ValueError("Context handle is none in context!!!")
|
||||
return value
|
||||
|
||||
def get_param(self, param):
|
||||
return ms_ctx_get_param(self._context_handle, param)
|
||||
|
||||
def set_param(self, param, value):
|
||||
ms_ctx_set_param(self._context_handle, param, value)
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._context_handle.get_execution_mode()
|
||||
return self.get_param(ms_ctx_param.execution_mode)
|
||||
|
||||
@mode.setter
|
||||
def mode(self, mode):
|
||||
|
@ -169,15 +175,17 @@ class _Context:
|
|||
Args:
|
||||
mode (int): GRAPH_MODE or PYNATIVE_MODE.
|
||||
"""
|
||||
self._context_handle.set_execution_mode(mode)
|
||||
if mode == PYNATIVE_MODE:
|
||||
if self.enable_debug_runtime:
|
||||
self.set_backend_policy("vm")
|
||||
self._context_switches.push(True, None)
|
||||
else:
|
||||
elif mode == GRAPH_MODE:
|
||||
if self.enable_debug_runtime:
|
||||
self.set_backend_policy("ge")
|
||||
self._context_switches.push(False, None)
|
||||
else:
|
||||
raise ValueError(f'The execution mode {mode} is invalid!')
|
||||
self.set_param(ms_ctx_param.execution_mode, mode)
|
||||
|
||||
def set_backend_policy(self, policy):
|
||||
success = self._context_handle.set_backend_policy(policy)
|
||||
|
@ -186,110 +194,106 @@ class _Context:
|
|||
|
||||
@property
|
||||
def precompile_only(self):
|
||||
return self._context_handle.get_precompile_only()
|
||||
return self.get_param(ms_ctx_param.precompile_only)
|
||||
|
||||
@precompile_only.setter
|
||||
def precompile_only(self, precompile_only):
|
||||
self._context_handle.set_precompile_only(precompile_only)
|
||||
self.set_param(ms_ctx_param.precompile_only, precompile_only)
|
||||
|
||||
@property
|
||||
def save_graphs(self):
|
||||
return self._context_handle.get_save_graphs_flag()
|
||||
return self.get_param(ms_ctx_param.save_graphs_flag)
|
||||
|
||||
@save_graphs.setter
|
||||
def save_graphs(self, save_graphs_flag):
|
||||
self._context_handle.set_save_graphs_flag(save_graphs_flag)
|
||||
self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag)
|
||||
|
||||
@property
|
||||
def save_graphs_path(self):
|
||||
return self._context_handle.get_save_graphs_path()
|
||||
return self.get_param(ms_ctx_param.save_graphs_path)
|
||||
|
||||
@save_graphs_path.setter
|
||||
def save_graphs_path(self, save_graphs_path):
|
||||
self._context_handle.set_save_graphs_path(
|
||||
_make_directory(save_graphs_path))
|
||||
self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
|
||||
|
||||
@property
|
||||
def device_target(self):
|
||||
return self._context_handle.get_device_target()
|
||||
return self.get_param(ms_ctx_param.device_target)
|
||||
|
||||
@device_target.setter
|
||||
def device_target(self, target):
|
||||
success = self._context_handle.set_device_target(target)
|
||||
if not success:
|
||||
raise ValueError("Target device name is invalid!!!")
|
||||
if self.enable_debug_runtime and self.device_target == "CPU":
|
||||
valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
|
||||
if not target in valid_targets:
|
||||
raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
|
||||
if target == "Davinci":
|
||||
target = "Ascend"
|
||||
self.set_param(ms_ctx_param.device_target, target)
|
||||
if self.enable_debug_runtime and target == "CPU":
|
||||
self.set_backend_policy("vm")
|
||||
|
||||
@property
|
||||
def device_id(self):
|
||||
return self._context_handle.get_device_id()
|
||||
return self.get_param(ms_ctx_param.device_id)
|
||||
|
||||
@device_id.setter
|
||||
def device_id(self, device_id):
|
||||
if device_id < 0 or device_id > 4095:
|
||||
raise ValueError(
|
||||
"Device id must be in [0, 4095], but got {}".format(device_id))
|
||||
success = self._context_handle.set_device_id(device_id)
|
||||
if not success:
|
||||
raise RuntimeError("Device id set failed!!!")
|
||||
raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
|
||||
self.set_param(ms_ctx_param.device_id, device_id)
|
||||
|
||||
@property
|
||||
def max_call_depth(self):
|
||||
return self._context_handle.get_max_call_depth()
|
||||
return self.get_param(ms_ctx_param.max_call_depth)
|
||||
|
||||
@max_call_depth.setter
|
||||
def max_call_depth(self, max_call_depth):
|
||||
if max_call_depth <= 0:
|
||||
raise ValueError(
|
||||
"Max call depth must be greater than 0, but got {}".format(max_call_depth))
|
||||
self._context_handle.set_max_call_depth(max_call_depth)
|
||||
raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
|
||||
self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
|
||||
|
||||
@property
|
||||
def enable_auto_mixed_precision(self):
|
||||
return self._context_handle.get_auto_mixed_precision_flag()
|
||||
return self.get_param(ms_ctx_param.auto_mixed_precision_flag)
|
||||
|
||||
@enable_auto_mixed_precision.setter
|
||||
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
|
||||
self._context_handle.set_auto_mixed_precision_flag(
|
||||
enable_auto_mixed_precision)
|
||||
self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision)
|
||||
|
||||
@property
|
||||
def enable_reduce_precision(self):
|
||||
return self._context_handle.get_enable_reduce_precision_flag()
|
||||
return self.get_param(ms_ctx_param.enable_reduce_precision_flag)
|
||||
|
||||
@enable_reduce_precision.setter
|
||||
def enable_reduce_precision(self, enable_reduce_precision):
|
||||
self._context_handle.set_enable_reduce_precision_flag(
|
||||
enable_reduce_precision)
|
||||
self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision)
|
||||
|
||||
@property
|
||||
def enable_dump(self):
|
||||
return self._context_handle.get_enable_dump()
|
||||
return self.get_param(ms_ctx_param.enable_dump)
|
||||
|
||||
@enable_dump.setter
|
||||
def enable_dump(self, enable_dump):
|
||||
self._context_handle.set_enable_dump(enable_dump)
|
||||
self.set_param(ms_ctx_param.enable_dump, enable_dump)
|
||||
|
||||
@property
|
||||
def save_dump_path(self):
|
||||
return self._context_handle.get_save_dump_path()
|
||||
return self.get_param(ms_ctx_param.save_dump_path)
|
||||
|
||||
@save_dump_path.setter
|
||||
def save_dump_path(self, save_dump_path):
|
||||
self._context_handle.set_save_dump_path(save_dump_path)
|
||||
self.set_param(ms_ctx_param.save_dump_path, save_dump_path)
|
||||
|
||||
@property
|
||||
def enable_profiling(self):
|
||||
return self._context_handle.get_enable_profiling()
|
||||
return self.get_param(ms_ctx_param.enable_profiling)
|
||||
|
||||
@enable_profiling.setter
|
||||
def enable_profiling(self, flag):
|
||||
self._context_handle.set_enable_profiling(flag)
|
||||
self.set_param(ms_ctx_param.enable_profiling, flag)
|
||||
|
||||
@property
|
||||
def profiling_options(self):
|
||||
return self._context_handle.get_profiling_options()
|
||||
return self.get_param(ms_ctx_param.profiling_options)
|
||||
|
||||
@profiling_options.setter
|
||||
def profiling_options(self, option):
|
||||
|
@ -298,15 +302,15 @@ class _Context:
|
|||
if option not in options:
|
||||
raise ValueError("Profiling options must be in 'training_trace' 'task_trace' "
|
||||
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
|
||||
self._context_handle.set_profiling_options(option)
|
||||
self.set_param(ms_ctx_param.profiling_options, option)
|
||||
|
||||
@property
|
||||
def enable_graph_kernel(self):
|
||||
return self._context_handle.get_enable_graph_kernel()
|
||||
return self.get_param(ms_ctx_param.enable_graph_kernel)
|
||||
|
||||
@enable_graph_kernel.setter
|
||||
def enable_graph_kernel(self, graph_kernel_switch_):
|
||||
self._context_handle.set_enable_graph_kernel(graph_kernel_switch_)
|
||||
self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_)
|
||||
|
||||
@property
|
||||
def reserve_class_name_in_scope(self):
|
||||
|
@ -325,20 +329,14 @@ class _Context:
|
|||
@variable_memory_max_size.setter
|
||||
def variable_memory_max_size(self, variable_memory_max_size):
|
||||
if not check_input_format(variable_memory_max_size):
|
||||
raise ValueError(
|
||||
"Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
|
||||
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
|
||||
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
|
||||
raise ValueError(
|
||||
"Context param variable_memory_max_size should be less than 31GB.")
|
||||
variable_memory_max_size_ = variable_memory_max_size[:-
|
||||
2] + " * 1024 * 1024 * 1024"
|
||||
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - \
|
||||
int(variable_memory_max_size[:-2])
|
||||
graph_memory_max_size_ = str(
|
||||
graph_memory_max_size) + " * 1024 * 1024 * 1024"
|
||||
self._context_handle.set_variable_memory_max_size(
|
||||
variable_memory_max_size_)
|
||||
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
|
||||
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
|
||||
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
|
||||
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
|
||||
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
|
||||
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
|
||||
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
|
||||
|
||||
@property
|
||||
def enable_ge(self):
|
||||
|
@ -355,15 +353,15 @@ class _Context:
|
|||
|
||||
@property
|
||||
def check_bprop(self):
|
||||
return self._context_handle.get_check_bprop_flag()
|
||||
return self.get_param(ms_ctx_param.check_bprop_flag)
|
||||
|
||||
@check_bprop.setter
|
||||
def check_bprop(self, check_bprop_flag):
|
||||
self._context_handle.set_check_bprop_flag(check_bprop_flag)
|
||||
self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag)
|
||||
|
||||
@property
|
||||
def max_device_memory(self):
|
||||
return self._context_handle.get_max_device_memory()
|
||||
return self.get_param(ms_ctx_param.max_device_memory)
|
||||
|
||||
@max_device_memory.setter
|
||||
def max_device_memory(self, max_device_memory):
|
||||
|
@ -372,7 +370,7 @@ class _Context:
|
|||
max_device_memory_value = float(max_device_memory[:-2])
|
||||
if max_device_memory_value == 0:
|
||||
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
|
||||
self._context_handle.set_max_device_memory(max_device_memory_value)
|
||||
self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
|
||||
|
||||
@property
|
||||
def print_file_path(self):
|
||||
|
@ -392,15 +390,15 @@ class _Context:
|
|||
full_file_name = os.path.join(path, file_name)
|
||||
else:
|
||||
full_file_name = print_file_path
|
||||
self._context_handle.set_print_file_path(full_file_name)
|
||||
self.set_param(ms_ctx_param.print_file_path, full_file_name)
|
||||
|
||||
@property
|
||||
def enable_sparse(self):
|
||||
return self._context_handle.get_enable_sparse()
|
||||
return self.get_param(ms_ctx_param.enable_sparse)
|
||||
|
||||
@enable_sparse.setter
|
||||
def enable_sparse(self, enable_sparse):
|
||||
self._context_handle.set_enable_sparse(enable_sparse)
|
||||
self.set_param(ms_ctx_param.enable_sparse, enable_sparse)
|
||||
|
||||
def check_input_format(x):
|
||||
import re
|
||||
|
@ -486,8 +484,6 @@ def set_auto_parallel_context(**kwargs):
|
|||
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
||||
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
|
||||
data parallel training in the benefit of time and memory saving.
|
||||
max_call_depth(int): Specify the function call depth limit. Default: 1000.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
@ -501,7 +497,6 @@ def set_auto_parallel_context(**kwargs):
|
|||
>>> context.set_auto_parallel_context(parameter_broadcast=False)
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(max_call_depth=80)
|
||||
"""
|
||||
_set_auto_parallel_context(**kwargs)
|
||||
|
||||
|
@ -603,6 +598,7 @@ def set_context(**kwargs):
|
|||
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
|
||||
suffix to the file.
|
||||
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
|
||||
max_call_depth(int): Specify the function call depth limit. Default: 1000.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not an attribute in context.
|
||||
|
@ -623,6 +619,7 @@ def set_context(**kwargs):
|
|||
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
|
||||
>>> context.set_context(max_device_memory="3.5GB")
|
||||
>>> context.set_context(print_file_path="print.pb")
|
||||
>>> context.set_context(max_call_depth=80)
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(_context(), key):
|
||||
|
|
|
@ -51,7 +51,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|||
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
|
||||
if (enable_sparse && dflt->isa<AbstractTensor>()) {
|
||||
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
|
||||
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
|
||||
|
|
|
@ -232,7 +232,7 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
|
|||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
return default_target;
|
||||
}
|
||||
|
||||
|
@ -248,7 +248,7 @@ std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &pri
|
|||
std::string GetCNodeTarget(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->device_target();
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (!node->isa<CNode>()) {
|
||||
return default_target;
|
||||
}
|
||||
|
|
|
@ -652,7 +652,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
|
|||
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
||||
}
|
||||
|
||||
if (MsContext::GetInstance()->is_multi_graph_sink()) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
|
||||
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
|
|||
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
|
||||
if (!enable_sparse) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -32,49 +32,50 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke
|
|||
{"vm_prior", kMsBackendVmPrior}};
|
||||
|
||||
MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||
save_graphs_flag_ = false;
|
||||
save_graphs_path_ = ".";
|
||||
enable_dump_ = false;
|
||||
save_dump_path_ = ".";
|
||||
tsd_ref_ = 0;
|
||||
ge_ref_ = 0;
|
||||
is_multi_graph_sink_ = false;
|
||||
is_pynative_ge_init_ = false;
|
||||
enable_reduce_precision_ = true;
|
||||
set_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG, false);
|
||||
set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
|
||||
set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
|
||||
set_param<uint32_t>(MS_CTX_TSD_REF, 0);
|
||||
set_param<uint32_t>(MS_CTX_GE_REF, 0);
|
||||
set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
|
||||
set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
|
||||
auto env_device = common::GetEnv("DEVICE_ID");
|
||||
if (!env_device.empty()) {
|
||||
device_id_ = UlongToUint(std::stoul(env_device.c_str()));
|
||||
uint32_t device_id = UlongToUint(std::stoul(env_device.c_str()));
|
||||
set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
|
||||
} else {
|
||||
device_id_ = 0;
|
||||
set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
|
||||
}
|
||||
max_call_depth_ = MAX_CALL_DEPTH_DEFAULT;
|
||||
backend_policy_ = policy_map_[policy];
|
||||
device_target_ = target;
|
||||
execution_mode_ = kPynativeMode;
|
||||
enable_task_sink_ = true;
|
||||
ir_fusion_flag_ = true;
|
||||
enable_hccl_ = false;
|
||||
set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
|
||||
set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
|
||||
set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
|
||||
set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
|
||||
set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
|
||||
set_param<bool>(MS_CTX_ENABLE_HCCL, false);
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
enable_mem_reuse_ = false;
|
||||
set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, false);
|
||||
#else
|
||||
enable_mem_reuse_ = true;
|
||||
set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, true);
|
||||
#endif
|
||||
enable_gpu_summary_ = true;
|
||||
precompile_only_ = false;
|
||||
auto_mixed_precision_flag_ = false;
|
||||
enable_pynative_infer_ = false;
|
||||
enable_pynative_hook_ = false;
|
||||
enable_dynamic_mem_pool_ = true;
|
||||
graph_memory_max_size_ = "0";
|
||||
variable_memory_max_size_ = "0";
|
||||
enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice;
|
||||
profiling_mode_ = false;
|
||||
profiling_options_ = "training_trace";
|
||||
check_bprop_flag_ = false;
|
||||
max_device_memory_ = kDefaultMaxDeviceMemory;
|
||||
print_file_path_ = "";
|
||||
enable_graph_kernel_ = false;
|
||||
enable_sparse_ = false;
|
||||
set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
|
||||
set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
|
||||
set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
|
||||
set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
|
||||
set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
|
||||
set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
|
||||
set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
|
||||
set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
|
||||
set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
|
||||
set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
|
||||
set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
|
||||
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
|
||||
|
||||
backend_policy_ = policy_map_[policy];
|
||||
}
|
||||
|
||||
std::shared_ptr<MsContext> MsContext::GetInstance() {
|
||||
|
@ -106,54 +107,4 @@ std::string MsContext::backend_policy() const {
|
|||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
void MsContext::set_execution_mode(int execution_mode) {
|
||||
if (execution_mode != kGraphMode && execution_mode != kPynativeMode) {
|
||||
MS_LOG(EXCEPTION) << "The execution mode is invalid!";
|
||||
}
|
||||
execution_mode_ = execution_mode;
|
||||
}
|
||||
|
||||
bool MsContext::set_device_target(const std::string &target) {
|
||||
if (kTargetSet.find(target) == kTargetSet.end()) {
|
||||
MS_LOG(ERROR) << "invalid device target name: " << target;
|
||||
return false;
|
||||
}
|
||||
if (target == kDavinciDevice) {
|
||||
device_target_ = kAscendDevice;
|
||||
} else {
|
||||
device_target_ = target;
|
||||
}
|
||||
if (seter_) {
|
||||
seter_(device_target_);
|
||||
}
|
||||
MS_LOG(INFO) << "ms set context device target:" << target;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MsContext::set_device_id(uint32_t device_id) {
|
||||
device_id_ = device_id;
|
||||
MS_LOG(INFO) << "ms set context device id:" << device_id;
|
||||
return true;
|
||||
}
|
||||
|
||||
void MsContext::set_tsd_ref(const std::string &op) {
|
||||
if (op == "--") {
|
||||
tsd_ref_--;
|
||||
} else if (op == "++") {
|
||||
tsd_ref_++;
|
||||
} else {
|
||||
tsd_ref_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void MsContext::set_ge_ref(const std::string &op) {
|
||||
if (op == "--") {
|
||||
ge_ref_--;
|
||||
} else if (op == "++") {
|
||||
ge_ref_++;
|
||||
} else {
|
||||
ge_ref_ = 0;
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,6 +49,69 @@ const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice,
|
|||
// The default max available device memory is 1024GB.
|
||||
const float kDefaultMaxDeviceMemory = 1024;
|
||||
|
||||
// enum definition for MindSpore Context Parameter
|
||||
enum MsCtxParam : unsigned {
|
||||
// paramater of type bool
|
||||
MS_CTX_TYPE_BOOL_BEGIN,
|
||||
MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN,
|
||||
MS_CTX_CHECK_BPROP_FLAG,
|
||||
MS_CTX_ENABLE_DUMP,
|
||||
MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
|
||||
MS_CTX_ENABLE_GPU_SUMMARY,
|
||||
MS_CTX_ENABLE_GRAPH_KERNEL,
|
||||
MS_CTX_ENABLE_HCCL,
|
||||
MS_CTX_ENABLE_LOOP_SINK,
|
||||
MS_CTX_ENABLE_MEM_REUSE,
|
||||
MS_CTX_ENABLE_PYNATIVE_HOOK,
|
||||
MS_CTX_ENABLE_PYNATIVE_INFER,
|
||||
MS_CTX_ENABLE_REDUCE_PRECISION,
|
||||
MS_CTX_ENABLE_SPARSE,
|
||||
MS_CTX_ENABLE_TASK_SINK,
|
||||
MS_CTX_IR_FUSION_FLAG,
|
||||
MS_CTX_IS_MULTI_GRAPH_SINK,
|
||||
MS_CTX_IS_PYNATIVE_GE_INIT,
|
||||
MS_CTX_PRECOMPILE_ONLY,
|
||||
MS_CTX_ENABLE_PROFILING,
|
||||
MS_CTX_SAVE_GRAPHS_FLAG,
|
||||
MS_CTX_TYPE_BOOL_END,
|
||||
|
||||
// paramater of type int
|
||||
MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
|
||||
MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
|
||||
MS_CTX_TYPE_INT_END,
|
||||
|
||||
// paramater of type uint32
|
||||
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
|
||||
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
|
||||
MS_CTX_GE_REF,
|
||||
MS_CTX_MAX_CALL_DEPTH,
|
||||
MS_CTX_TSD_REF,
|
||||
MS_CTX_TYPE_UINT32_END,
|
||||
|
||||
// paramater of type float
|
||||
MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
|
||||
MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
|
||||
MS_CTX_TYPE_FLOAT_END,
|
||||
|
||||
// paramater of type string
|
||||
MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
|
||||
MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
|
||||
MS_CTX_GRAPH_MEMORY_MAX_SIZE,
|
||||
MS_CTX_PRINT_FILE_PATH,
|
||||
MS_CTX_PROFILING_OPTIONS,
|
||||
MS_CTX_SAVE_DUMP_PATH,
|
||||
MS_CTX_SAVE_GRAPHS_PATH,
|
||||
MS_CTX_VARIABLE_MEMORY_MAX_SIZE,
|
||||
MS_CTX_TYPE_STRING_END,
|
||||
|
||||
// parameter numbers of each type
|
||||
NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN,
|
||||
NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN,
|
||||
NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN,
|
||||
NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN,
|
||||
NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN
|
||||
};
|
||||
|
||||
class MsContext {
|
||||
public:
|
||||
MsContext(const std::string &backend_policy, const std::string &target);
|
||||
|
@ -62,156 +125,113 @@ class MsContext {
|
|||
std::string backend_policy() const;
|
||||
bool set_backend_policy(const std::string &policy);
|
||||
|
||||
int execution_mode() const { return execution_mode_; }
|
||||
void set_execution_mode(int execution_mode);
|
||||
|
||||
bool enable_pynative_infer() const { return enable_pynative_infer_; }
|
||||
void set_enable_pynative_infer(bool enable_pynative_infer) { enable_pynative_infer_ = enable_pynative_infer; }
|
||||
|
||||
bool enable_pynative_hook() const { return enable_pynative_hook_; }
|
||||
void set_enable_pynative_hook(bool enable_pynative_hook) { enable_pynative_hook_ = enable_pynative_hook; }
|
||||
|
||||
bool enable_task_sink() const { return enable_task_sink_; }
|
||||
|
||||
void set_precompile_only(bool precompile_only) { precompile_only_ = precompile_only; }
|
||||
bool precompile_only() const { return precompile_only_; }
|
||||
|
||||
std::string device_target() const { return device_target_; }
|
||||
bool set_device_target(const std::string &target);
|
||||
|
||||
uint32_t device_id() const { return device_id_; }
|
||||
bool set_device_id(uint32_t device_id);
|
||||
|
||||
// uint32_t max_call_depth_
|
||||
uint32_t max_call_depth() const { return max_call_depth_; }
|
||||
inline bool set_max_call_depth(uint32_t max_call_depth) {
|
||||
max_call_depth_ = max_call_depth;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool save_graphs_flag() const { return save_graphs_flag_; }
|
||||
void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; }
|
||||
|
||||
std::string save_graphs_path() const { return save_graphs_path_; }
|
||||
void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; }
|
||||
|
||||
bool IsGeInited() { return ge_ref_ > 0; }
|
||||
void set_enable_hccl(bool enable_hccl) { enable_hccl_ = enable_hccl; }
|
||||
bool enable_hccl() const { return enable_hccl_; }
|
||||
bool ir_fusion_flag() const { return ir_fusion_flag_; }
|
||||
bool loop_sink_flag() const { return enable_loop_sink_; }
|
||||
void set_loop_sink_flag(bool enable_loop_sink) { enable_loop_sink_ = enable_loop_sink; }
|
||||
void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
|
||||
bool enable_mem_reuse() const { return enable_mem_reuse_; }
|
||||
|
||||
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
|
||||
bool enable_gpu_summary() const { return enable_gpu_summary_; }
|
||||
|
||||
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
|
||||
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
|
||||
}
|
||||
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
|
||||
|
||||
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
|
||||
bool enable_reduce_precision() const { return enable_reduce_precision_; }
|
||||
|
||||
void set_enable_dump(bool flag) { enable_dump_ = flag; }
|
||||
bool enable_dump() const { return enable_dump_; }
|
||||
|
||||
void set_save_dump_path(const std::string &path) { save_dump_path_ = path; }
|
||||
std::string save_dump_path() const { return save_dump_path_; }
|
||||
|
||||
bool IsTsdOpened() const { return tsd_ref_ > 0; }
|
||||
void set_tsd_ref(const std::string &op);
|
||||
uint32_t tsd_ref() const { return tsd_ref_; }
|
||||
|
||||
void set_ge_ref(const std::string &op);
|
||||
uint32_t ge_ref() const { return ge_ref_; }
|
||||
|
||||
bool is_pynative_ge_init() { return is_pynative_ge_init_; }
|
||||
void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; }
|
||||
|
||||
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
|
||||
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }
|
||||
|
||||
void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; }
|
||||
bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; }
|
||||
|
||||
void set_graph_memory_max_size(const std::string &graph_memory_max_size) {
|
||||
graph_memory_max_size_ = graph_memory_max_size;
|
||||
}
|
||||
|
||||
void set_variable_memory_max_size(const std::string &variable_memory_max_size) {
|
||||
variable_memory_max_size_ = variable_memory_max_size;
|
||||
}
|
||||
|
||||
const std::string &variable_memory_max_size() const { return variable_memory_max_size_; }
|
||||
|
||||
const std::string &graph_memory_max_size() const { return graph_memory_max_size_; }
|
||||
|
||||
void set_enable_profiling(bool flag) { profiling_mode_ = flag; }
|
||||
bool enable_profiling() const { return profiling_mode_; }
|
||||
|
||||
void set_profiling_options(const std::string &options) { profiling_options_ = options; }
|
||||
std::string profiling_options() const { return profiling_options_; }
|
||||
bool check_bprop_flag() const { return check_bprop_flag_; }
|
||||
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
|
||||
void set_print_file_path(const std::string &file) { print_file_path_ = file; }
|
||||
const std::string &print_file_path() const { return print_file_path_; }
|
||||
|
||||
float max_device_memory() const { return max_device_memory_; }
|
||||
void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; }
|
||||
|
||||
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
|
||||
bool enable_graph_kernel() const { return enable_graph_kernel_; }
|
||||
|
||||
bool enable_sparse() const { return enable_sparse_; }
|
||||
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
|
||||
static void device_seter(DeviceSeter device) { seter_ = device; }
|
||||
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
|
||||
|
||||
std::thread tdt_print_;
|
||||
|
||||
template <typename T>
|
||||
void set_param(MsCtxParam param, const T &value) {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T &get_param(MsCtxParam param) const {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void increase_param(MsCtxParam param) {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void decrease_param(MsCtxParam param) {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
private:
|
||||
inline static DeviceSeter seter_ = nullptr;
|
||||
inline static DeviceTypeSeter device_type_seter_ = nullptr;
|
||||
static std::shared_ptr<MsContext> inst_context_;
|
||||
static std::map<std::string, MsBackendPolicy> policy_map_;
|
||||
|
||||
bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS];
|
||||
int int_params_[MsCtxParam::NUM_INT_PARAMS];
|
||||
uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
|
||||
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
|
||||
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
|
||||
|
||||
MsBackendPolicy backend_policy_;
|
||||
std::string device_target_;
|
||||
uint32_t device_id_;
|
||||
uint32_t max_call_depth_;
|
||||
int execution_mode_;
|
||||
bool enable_pynative_infer_;
|
||||
bool enable_pynative_hook_;
|
||||
bool save_graphs_flag_;
|
||||
std::string save_graphs_path_;
|
||||
uint32_t tsd_ref_;
|
||||
uint32_t ge_ref_;
|
||||
bool enable_task_sink_;
|
||||
bool enable_hccl_;
|
||||
bool precompile_only_;
|
||||
bool ir_fusion_flag_;
|
||||
bool auto_mixed_precision_flag_;
|
||||
bool enable_reduce_precision_;
|
||||
bool enable_loop_sink_;
|
||||
bool enable_mem_reuse_;
|
||||
bool enable_gpu_summary_;
|
||||
bool enable_dump_;
|
||||
std::string save_dump_path_;
|
||||
bool is_multi_graph_sink_;
|
||||
bool is_pynative_ge_init_;
|
||||
bool enable_dynamic_mem_pool_;
|
||||
std::string graph_memory_max_size_;
|
||||
std::string variable_memory_max_size_;
|
||||
bool profiling_mode_;
|
||||
std::string profiling_options_;
|
||||
bool check_bprop_flag_;
|
||||
float max_device_memory_;
|
||||
std::string print_file_path_;
|
||||
bool enable_graph_kernel_;
|
||||
bool enable_sparse_;
|
||||
};
|
||||
|
||||
// set method implementation for type bool/int/uint32_t/float/std::string
|
||||
template <>
|
||||
inline void MsContext::set_param<bool>(MsCtxParam param, const bool &value) {
|
||||
bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MsContext::set_param<int>(MsCtxParam param, const int &value) {
|
||||
int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MsContext::set_param<uint32_t>(MsCtxParam param, const uint32_t &value) {
|
||||
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MsContext::set_param<float>(MsCtxParam param, const float &value) {
|
||||
float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MsContext::set_param<std::string>(MsCtxParam param, const std::string &value) {
|
||||
if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) {
|
||||
MS_LOG(INFO) << "ms set context device target:" << value;
|
||||
seter_(value);
|
||||
}
|
||||
string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value;
|
||||
}
|
||||
|
||||
// get method implementation for type bool/int/uint32_t/float/std::string
|
||||
template <>
|
||||
inline const bool &MsContext::get_param<bool>(MsCtxParam param) const {
|
||||
return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const int &MsContext::get_param<int>(MsCtxParam param) const {
|
||||
return int_params_[param - MS_CTX_TYPE_INT_BEGIN];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const {
|
||||
return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const float &MsContext::get_param<float>(MsCtxParam param) const {
|
||||
return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const {
|
||||
return string_params_[param - MS_CTX_TYPE_STRING_BEGIN];
|
||||
}
|
||||
|
||||
// increate method implementation for type uint32_t
|
||||
template <>
|
||||
inline void MsContext::increase_param<uint32_t>(MsCtxParam param) {
|
||||
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++;
|
||||
}
|
||||
|
||||
// decreate method implementation for type uint32_t
|
||||
template <>
|
||||
inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) {
|
||||
uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_
|
||||
|
|
|
@ -42,7 +42,7 @@ class TestOptLib : public UT::Common {
|
|||
parse::data_converter::ClearObjectCache();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
}
|
||||
FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) {
|
||||
equiv_node.clear();
|
||||
|
|
|
@ -112,7 +112,7 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
|
|||
*/
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
auto fg = GetSingleOutputGraph("test_insert_trans_op_for_single_output", "before", "NC1HWC0");
|
||||
// Do insert_trans_op_ pass of hardware opt
|
||||
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
|
|
|
@ -112,7 +112,7 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
|
|||
TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
|
||||
// insert trans op for output
|
||||
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
|
|
|
@ -104,7 +104,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
|
|||
*/
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transdata_split_fraz_nchw", "before");
|
||||
std::vector<int> shp{2, 4, 8, 16};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
|
|||
*/
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_transdata_fusion", "before");
|
||||
std::vector<int> shp{2, 4, 8, 16};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
|
|
|
@ -76,7 +76,7 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
|
|||
*/
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_eliminate_5to4_4to5", "before");
|
||||
// Renormalize func_graph to infer and set shape and type information.
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
|
|
|
@ -71,15 +71,15 @@ OpExecInfoPtr ConstructOpExecInfo() {
|
|||
TEST_F(TestPynativeExecute, TestCreateContext) {
|
||||
auto ctx3 = MsContext::GetInstance();
|
||||
ASSERT_EQ(ctx3->backend_policy(), "vm");
|
||||
ASSERT_EQ(ctx3->device_target(), "CPU");
|
||||
ASSERT_EQ(ctx3->get_param<std::string>(MS_CTX_DEVICE_TARGET), "CPU");
|
||||
|
||||
ctx3->set_backend_policy("ge_only");
|
||||
ctx3->set_device_target("GPU");
|
||||
ctx3->set_param<std::string>(MS_CTX_DEVICE_TARGET, "GPU");
|
||||
auto ctx4 = MsContext::GetInstance();
|
||||
|
||||
ASSERT_EQ(ctx3.get(), ctx4.get());
|
||||
ASSERT_EQ(ctx4->backend_policy(), "ge_only");
|
||||
ASSERT_EQ(ctx4->device_target(), "GPU");
|
||||
ASSERT_EQ(ctx4->get_param<std::string>(MS_CTX_DEVICE_TARGET), "GPU");
|
||||
}
|
||||
|
||||
TEST_F(TestPynativeExecute, TestDefaultContext) {
|
||||
|
|
Loading…
Reference in New Issue