!2552 check whether mpi instance is null

Merge pull request !2552 from chenjianping/host_reduce
This commit is contained in:
mindspore-ci-bot 2020-06-24 17:03:51 +08:00 committed by Gitee
commit 927278be44
5 changed files with 17 additions and 11 deletions

View File

@ -60,7 +60,9 @@ std::string GetRankId() {
auto mpi_config_ptr = MpiConfig::GetInstance(); auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr); MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) { if (mpi_config_ptr->enable_mpi()) {
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId(); auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
int rank_id = mpi_instance->GetRankId();
const char *offset = std::getenv("RANK_OFFSET"); const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) { if (offset != nullptr) {
try { try {

View File

@ -46,8 +46,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto input_data_num = inputs[0]->size / sizeof(float); auto input_data_num = inputs[0]->size / sizeof(float);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num); MS_EXCEPTION_IF_NULL(mpi_instance);
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -50,9 +50,11 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7}; const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
size_t input_split_lens = input_size / split_num_ / sizeof(float_t); size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
size_t output_split_lens = output_size / split_num_ / sizeof(float_t); size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
for (int i = 0; i < split_num_; i++) { for (int i = 0; i < split_num_; i++) {
device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens, mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
output_addr + i * output_split_lens, rank_group, input_split_lens); input_split_lens);
} }
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now(); auto end_time = std::chrono::steady_clock::now();

View File

@ -104,10 +104,11 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
size_t reduce_scatter_out_lens = one_split_lens / 8; size_t reduce_scatter_out_lens = one_split_lens / 8;
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7}; const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
auto mpi_instance = device::cpu::MPIAdapter::Instance();
MS_EXCEPTION_IF_NULL(mpi_instance);
for (int i = 0; i < split_num_; i++) { for (int i = 0; i < split_num_; i++) {
device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens, mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
output_addr + i * reduce_scatter_out_lens, group, output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
one_split_lens / 8, "sum");
} }
} }
#endif #endif

View File

@ -46,9 +46,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto output_data_num = outputs[0]->size / sizeof(float); auto output_data_num = outputs[0]->size / sizeof(float);
auto mpi_instance = device::cpu::MPIAdapter::Instance();
return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, MS_EXCEPTION_IF_NULL(mpi_instance);
op_type_); return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore