forked from mindspore-Ecosystem/mindspore
!2552 check whether mpi instance is null
Merge pull request !2552 from chenjianping/host_reduce
This commit is contained in:
commit
927278be44
|
@ -60,7 +60,9 @@ std::string GetRankId() {
|
|||
auto mpi_config_ptr = MpiConfig::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
|
||||
if (mpi_config_ptr->enable_mpi()) {
|
||||
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
int rank_id = mpi_instance->GetRankId();
|
||||
const char *offset = std::getenv("RANK_OFFSET");
|
||||
if (offset != nullptr) {
|
||||
try {
|
||||
|
|
|
@ -46,8 +46,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto input_data_num = inputs[0]->size / sizeof(float);
|
||||
|
||||
return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,9 +50,11 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
|
|||
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
||||
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
for (int i = 0; i < split_num_; i++) {
|
||||
device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
|
||||
output_addr + i * output_split_lens, rank_group, input_split_lens);
|
||||
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
|
||||
input_split_lens);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
|
|
|
@ -104,10 +104,11 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
|
||||
size_t reduce_scatter_out_lens = one_split_lens / 8;
|
||||
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
for (int i = 0; i < split_num_; i++) {
|
||||
device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
|
||||
output_addr + i * reduce_scatter_out_lens, group,
|
||||
one_split_lens / 8, "sum");
|
||||
mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
|
||||
output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -46,9 +46,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto output_data_num = outputs[0]->size / sizeof(float);
|
||||
|
||||
return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
|
||||
op_type_);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue