forked from mindspore-Ecosystem/mindspore
!2552 check whether mpi instance is null
Merge pull request !2552 from chenjianping/host_reduce
This commit is contained in:
commit
927278be44
|
@ -60,7 +60,9 @@ std::string GetRankId() {
|
||||||
auto mpi_config_ptr = MpiConfig::GetInstance();
|
auto mpi_config_ptr = MpiConfig::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
|
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
|
||||||
if (mpi_config_ptr->enable_mpi()) {
|
if (mpi_config_ptr->enable_mpi()) {
|
||||||
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||||
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||||
|
int rank_id = mpi_instance->GetRankId();
|
||||||
const char *offset = std::getenv("RANK_OFFSET");
|
const char *offset = std::getenv("RANK_OFFSET");
|
||||||
if (offset != nullptr) {
|
if (offset != nullptr) {
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -46,8 +46,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||||
auto input_data_num = inputs[0]->size / sizeof(float);
|
auto input_data_num = inputs[0]->size / sizeof(float);
|
||||||
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||||
return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||||
|
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -50,9 +50,11 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
|
||||||
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||||
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
||||||
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
||||||
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||||
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||||
for (int i = 0; i < split_num_; i++) {
|
for (int i = 0; i < split_num_; i++) {
|
||||||
device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
|
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
|
||||||
output_addr + i * output_split_lens, rank_group, input_split_lens);
|
input_split_lens);
|
||||||
}
|
}
|
||||||
#if defined(_WIN32) || defined(_WIN64)
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
auto end_time = std::chrono::steady_clock::now();
|
auto end_time = std::chrono::steady_clock::now();
|
||||||
|
|
|
@ -104,10 +104,11 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
||||||
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
|
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
|
||||||
size_t reduce_scatter_out_lens = one_split_lens / 8;
|
size_t reduce_scatter_out_lens = one_split_lens / 8;
|
||||||
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
|
const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||||
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||||
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||||
for (int i = 0; i < split_num_; i++) {
|
for (int i = 0; i < split_num_; i++) {
|
||||||
device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
|
mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
|
||||||
output_addr + i * reduce_scatter_out_lens, group,
|
output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
|
||||||
one_split_lens / 8, "sum");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -46,9 +46,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
||||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||||
auto output_data_num = outputs[0]->size / sizeof(float);
|
auto output_data_num = outputs[0]->size / sizeof(float);
|
||||||
|
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||||
return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
|
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||||
op_type_);
|
return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue