!2552 check whether mpi instance is null

Merge pull request !2552 from chenjianping/host_reduce
This commit is contained in:
mindspore-ci-bot 2020-06-24 17:03:51 +08:00 committed by Gitee
commit 927278be44
5 changed files with 17 additions and 11 deletions

View File

@ -60,7 +60,9 @@ std::string GetRankId() {
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) {
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
int rank_id = mpi_instance->GetRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
try {

View File

@ -46,8 +46,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto input_data_num = inputs[0]->size / sizeof(float);
return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
}
} // namespace kernel
} // namespace mindspore

View File

@ -50,9 +50,11 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
for (int i = 0; i < split_num_; i++) {
device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
output_addr + i * output_split_lens, rank_group, input_split_lens);
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
input_split_lens);
}
#if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now();

View File

@ -104,10 +104,11 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
size_t reduce_scatter_out_lens = one_split_lens / 8;
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
for (int i = 0; i < split_num_; i++) {
device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
output_addr + i * reduce_scatter_out_lens, group,
one_split_lens / 8, "sum");
mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
}
}
#endif

View File

@ -46,9 +46,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto output_data_num = outputs[0]->size / sizeof(float);
return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
op_type_);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
}
} // namespace kernel
} // namespace mindspore