Get rank id when the HCCL env is set for single-card training

This commit is contained in:
yelihua 2021-08-16 13:24:05 +08:00
parent fca1cb34c8
commit a6dc9a0a07
5 changed files with 47 additions and 10 deletions

View File

@ -847,12 +847,7 @@ void AscendSession::InitRuntimeResource() {
if (!runtime_instance->Init()) {
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
}
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto env_rank_id = common::GetEnv("RANK_ID");
if (!(env_table_file.empty() || env_rank_id.empty())) {
// get actual rank id if it's distribution training case.
rank_id_ = GetRankId();
}
DumpInit(rank_id_);
MS_LOG(INFO) << "Finish!";
}

View File

@ -2683,6 +2683,11 @@ uint32_t GetRankId() {
uint32_t rank_id = 0;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto env_rank_id = common::GetEnv("RANK_ID");
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) || env_rank_id.empty()) {
MS_LOG(INFO) << "HCCL not enabled, use 0 as default rank id.";
return rank_id;
}
std::string world_group;
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend == kAscendDevice) {
@ -2691,6 +2696,7 @@ uint32_t GetRankId() {
world_group = kNcclWorldGroup;
} else {
MS_LOG(ERROR) << "Invalid backend: " << backend;
return rank_id;
}
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
MS_LOG(INFO) << "Failed to get rank id.";

View File

@ -550,9 +550,10 @@ std::string DumpJsonParser::GetOpOverflowBinPath(uint32_t graph_id) const {
bin_path.append("rank_");
uint32_t rank_id = 0;
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto env_rank_id = common::GetEnv("RANK_ID");
if (!(env_table_file.empty() || env_rank_id.empty())) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
// get actual rank id if it's distribution training case.
if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
MS_LOG(INFO) << "Failed to get rank id.";

View File

@ -133,9 +133,11 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
}
uint32_t graph_id = kernel_graph_->graph_id();
uint32_t rank_id = 0;
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto env_rank_id = common::GetEnv("RANK_ID");
if (!(env_table_file.empty() || env_rank_id.empty())) {
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
// get actual rank id if it's distribution training case.
if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
MS_LOG(INFO) << "Failed to get rank id.";

View File

@ -119,6 +119,17 @@ def test_e2e_dump():
run_e2e_dump()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_e2e_dump_with_hccl_env():
    """Run the e2e dump on Ascend with RANK_TABLE_FILE and RANK_ID set in the environment.

    The two variables are set only for the duration of ``run_e2e_dump()`` and
    are restored afterwards, so other tests running in the same process do not
    inherit them.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # NOTE(review): "invalid_file.json" presumably does not refer to a real rank table - confirm.
    hccl_env = {"RANK_TABLE_FILE": "invalid_file.json", "RANK_ID": "4"}
    saved_env = {key: os.environ.get(key) for key in hccl_env}
    os.environ.update(hccl_env)
    try:
        run_e2e_dump()
    finally:
        # Put back whatever was there before (or remove the key if it was unset),
        # even when run_e2e_dump() raises.
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -126,6 +137,17 @@ def test_cpu_e2e_dump():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
run_e2e_dump()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_e2e_dump_with_hccl_set():
    """Run the e2e dump on CPU with RANK_TABLE_FILE and RANK_ID set in the environment.

    The two variables are set only for the duration of ``run_e2e_dump()`` and
    are restored afterwards, so other tests running in the same process do not
    inherit them.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # NOTE(review): "invalid_file.json" presumably does not refer to a real rank table - confirm.
    hccl_env = {"RANK_TABLE_FILE": "invalid_file.json", "RANK_ID": "4"}
    saved_env = {key: os.environ.get(key) for key in hccl_env}
    os.environ.update(hccl_env)
    try:
        run_e2e_dump()
    finally:
        # Put back whatever was there before (or remove the key if it was unset),
        # even when run_e2e_dump() raises.
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -133,6 +155,17 @@ def test_gpu_e2e_dump():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
run_e2e_dump()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_e2e_dump_with_hccl_set():
    """Run the e2e dump on GPU with RANK_TABLE_FILE and RANK_ID set in the environment.

    The two variables are set only for the duration of ``run_e2e_dump()`` and
    are restored afterwards, so other tests running in the same process do not
    inherit them.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # NOTE(review): "invalid_file.json" presumably does not refer to a real rank table - confirm.
    hccl_env = {"RANK_TABLE_FILE": "invalid_file.json", "RANK_ID": "4"}
    saved_env = {key: os.environ.get(key) for key in hccl_env}
    os.environ.update(hccl_env)
    try:
        run_e2e_dump()
    finally:
        # Put back whatever was there before (or remove the key if it was unset),
        # even when run_e2e_dump() raises.
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
class ReluReduceMeanDenseRelu(Cell):
def __init__(self, kernel, bias, in_channel, num_class):
super().__init__()