From a6dc9a0a070fa12d1a2874e73fb4f6ac4f2335c4 Mon Sep 17 00:00:00 2001 From: yelihua Date: Mon, 16 Aug 2021 13:24:05 +0800 Subject: [PATCH] get rank id when set hccl env for single card train --- .../ccsrc/backend/session/ascend_session.cc | 7 +--- .../ccsrc/backend/session/session_basic.cc | 6 ++++ .../ccsrc/debug/data_dump/dump_json_parser.cc | 5 +-- .../runtime/device/ascend/dump/data_dumper.cc | 6 ++-- tests/st/dump/test_data_dump.py | 33 +++++++++++++++++++ 5 files changed, 47 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 74ef457e603..93ef2e90897 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -847,12 +847,7 @@ void AscendSession::InitRuntimeResource() { if (!runtime_instance->Init()) { MS_LOG(EXCEPTION) << "Kernel runtime init error."; } - auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); - auto env_rank_id = common::GetEnv("RANK_ID"); - if (!(env_table_file.empty() || env_rank_id.empty())) { - // get actual rank id if it's distribution training case. - rank_id_ = GetRankId(); - } + rank_id_ = GetRankId(); DumpInit(rank_id_); MS_LOG(INFO) << "Finish!"; } diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 22cd874e937..8b8da0cbf10 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -2683,6 +2683,11 @@ uint32_t GetRankId() { uint32_t rank_id = 0; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); + auto env_rank_id = common::GetEnv("RANK_ID"); + if (!ms_context->get_param(MS_CTX_ENABLE_HCCL) || env_rank_id.empty()) { + MS_LOG(INFO) << "HCCL not enabled, use 0 as default rank id."; + return rank_id; + } std::string world_group; std::string backend = ms_context->get_param(MS_CTX_DEVICE_TARGET); if (backend == kAscendDevice) { @@ -2691,6 +2696,7 @@ uint32_t GetRankId() { world_group = kNcclWorldGroup; } else { MS_LOG(ERROR) << "Invalid backend: " << backend; + return rank_id; } if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { MS_LOG(INFO) << "Failed to get rank id."; diff --git a/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc b/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc index 094ca4755dc..a667d9e43a9 100644 --- a/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc +++ b/mindspore/ccsrc/debug/data_dump/dump_json_parser.cc @@ -550,9 +550,10 @@ std::string DumpJsonParser::GetOpOverflowBinPath(uint32_t graph_id) const { bin_path.append("rank_"); uint32_t rank_id = 0; - auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); auto env_rank_id = common::GetEnv("RANK_ID"); - if (!(env_table_file.empty() || env_rank_id.empty())) { + if (ms_context->get_param(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) { // get actual rank id if it's distribution training case. if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) { MS_LOG(INFO) << "Failed to get rank id."; diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc index aefcb8cc553..f35d78d336f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -133,9 +133,11 @@ void DataDumper::SetOpMappingInfo(NotNull dump_inf } uint32_t graph_id = kernel_graph_->graph_id(); uint32_t rank_id = 0; - auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); + + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); auto env_rank_id = common::GetEnv("RANK_ID"); - if (!(env_table_file.empty() || env_rank_id.empty())) { + if (ms_context->get_param(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) { // get actual rank id if it's distribution training case. if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) { MS_LOG(INFO) << "Failed to get rank id."; diff --git a/tests/st/dump/test_data_dump.py b/tests/st/dump/test_data_dump.py index f1b637084d7..2b280d86cf8 100644 --- a/tests/st/dump/test_data_dump.py +++ b/tests/st/dump/test_data_dump.py @@ -119,6 +119,17 @@ def test_e2e_dump(): run_e2e_dump() +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_e2e_dump_with_hccl_env(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + os.environ["RANK_TABLE_FILE"] = "invalid_file.json" + os.environ["RANK_ID"] = "4" + run_e2e_dump() + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -126,6 +137,17 @@ def test_cpu_e2e_dump(): context.set_context(mode=context.GRAPH_MODE, device_target="CPU") run_e2e_dump() + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cpu_e2e_dump_with_hccl_set(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + os.environ["RANK_TABLE_FILE"] = "invalid_file.json" + os.environ["RANK_ID"] = "4" + run_e2e_dump() + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -133,6 +155,17 @@ def test_gpu_e2e_dump(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") run_e2e_dump() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gpu_e2e_dump_with_hccl_set(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + os.environ["RANK_TABLE_FILE"] = "invalid_file.json" + os.environ["RANK_ID"] = "4" + run_e2e_dump() + + class ReluReduceMeanDenseRelu(Cell): def __init__(self, kernel, bias, in_channel, num_class): super().__init__()