forked from mindspore-Ecosystem/mindspore
Get the rank id when the HCCL environment variables are set for single-card training
This commit is contained in:
parent
fca1cb34c8
commit
a6dc9a0a07
|
@ -847,12 +847,7 @@ void AscendSession::InitRuntimeResource() {
|
|||
if (!runtime_instance->Init()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
|
||||
}
|
||||
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
if (!(env_table_file.empty() || env_rank_id.empty())) {
|
||||
// get actual rank id if it's distribution training case.
|
||||
rank_id_ = GetRankId();
|
||||
}
|
||||
DumpInit(rank_id_);
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
|
|
@ -2683,6 +2683,11 @@ uint32_t GetRankId() {
|
|||
uint32_t rank_id = 0;
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) || env_rank_id.empty()) {
|
||||
MS_LOG(INFO) << "HCCL not enabled, use 0 as default rank id.";
|
||||
return rank_id;
|
||||
}
|
||||
std::string world_group;
|
||||
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (backend == kAscendDevice) {
|
||||
|
@ -2691,6 +2696,7 @@ uint32_t GetRankId() {
|
|||
world_group = kNcclWorldGroup;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
||||
return rank_id;
|
||||
}
|
||||
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
|
||||
MS_LOG(INFO) << "Failed to get rank id.";
|
||||
|
|
|
@ -550,9 +550,10 @@ std::string DumpJsonParser::GetOpOverflowBinPath(uint32_t graph_id) const {
|
|||
bin_path.append("rank_");
|
||||
|
||||
uint32_t rank_id = 0;
|
||||
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
if (!(env_table_file.empty() || env_rank_id.empty())) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
|
||||
// get actual rank id if it's distribution training case.
|
||||
if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
|
||||
MS_LOG(INFO) << "Failed to get rank id.";
|
||||
|
|
|
@ -133,9 +133,11 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
|
|||
}
|
||||
uint32_t graph_id = kernel_graph_->graph_id();
|
||||
uint32_t rank_id = 0;
|
||||
auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto env_rank_id = common::GetEnv("RANK_ID");
|
||||
if (!(env_table_file.empty() || env_rank_id.empty())) {
|
||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
|
||||
// get actual rank id if it's distribution training case.
|
||||
if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
|
||||
MS_LOG(INFO) << "Failed to get rank id.";
|
||||
|
|
|
@ -119,6 +119,17 @@ def test_e2e_dump():
|
|||
run_e2e_dump()
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_e2e_dump_with_hccl_env():
    """Run the e2e dump check on Ascend while HCCL-style env variables are set.

    RANK_TABLE_FILE and RANK_ID are set to placeholder values before
    run_e2e_dump() is called, and are restored to their previous state
    afterwards so they do not leak into other tests in the same process.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    hccl_env = {"RANK_TABLE_FILE": "invalid_file.json", "RANK_ID": "4"}
    # Remember the previous values (None means "was not set") for restoration.
    saved_env = {key: os.environ.get(key) for key in hccl_env}
    os.environ.update(hccl_env)
    try:
        run_e2e_dump()
    finally:
        # Restore even when the dump check fails, so later tests see a clean env.
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -126,6 +137,17 @@ def test_cpu_e2e_dump():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
run_e2e_dump()
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cpu_e2e_dump_with_hccl_set():
    """Run the e2e dump check on CPU while HCCL-style env variables are set.

    RANK_TABLE_FILE and RANK_ID are set to placeholder values before
    run_e2e_dump() is called, and are restored to their previous state
    afterwards so they do not leak into other tests in the same process.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    hccl_env = {"RANK_TABLE_FILE": "invalid_file.json", "RANK_ID": "4"}
    # Remember the previous values (None means "was not set") for restoration.
    saved_env = {key: os.environ.get(key) for key in hccl_env}
    os.environ.update(hccl_env)
    try:
        run_e2e_dump()
    finally:
        # Restore even when the dump check fails, so later tests see a clean env.
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -133,6 +155,17 @@ def test_gpu_e2e_dump():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
run_e2e_dump()
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_e2e_dump_with_hccl_set():
    """Run the e2e dump check on GPU while HCCL-style env variables are set.

    RANK_TABLE_FILE and RANK_ID are set to placeholder values before
    run_e2e_dump() is called, and are restored to their previous state
    afterwards so they do not leak into other tests in the same process.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    hccl_env = {"RANK_TABLE_FILE": "invalid_file.json", "RANK_ID": "4"}
    # Remember the previous values (None means "was not set") for restoration.
    saved_env = {key: os.environ.get(key) for key in hccl_env}
    os.environ.update(hccl_env)
    try:
        run_e2e_dump()
    finally:
        # Restore even when the dump check fails, so later tests see a clean env.
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
|
||||
|
||||
|
||||
class ReluReduceMeanDenseRelu(Cell):
|
||||
def __init__(self, kernel, bias, in_channel, num_class):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue