forked from OSSInnovation/mindspore
!2428 building _ms_mpi with mpi_interface
Merge pull request !2428 from chenjianping/host_reduce
Commit: cc2655e599

@@ -128,6 +128,11 @@ if (ENABLE_MPI)
        DESTINATION ${INSTALL_BASE_DIR}
        COMPONENT mindspore
    )
    install(
        TARGETS mpi_adapter
        DESTINATION ${INSTALL_LIB_DIR}
        COMPONENT mindspore
    )
endif ()

if (ENABLE_GPU)

@@ -145,11 +145,12 @@ if (ENABLE_DEBUGGER)
endif()

target_link_libraries(mindspore proto_input)
if (ENABLE_CPU AND ENABLE_MPI)
    target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
if (ENABLE_MPI)
    target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
else ()
    target_link_libraries(mindspore securec mindspore::flatbuffers)
endif ()

if (NOT WIN32)
    target_link_libraries(mindspore dl)
endif()

@@ -14,17 +14,22 @@ endif ()

if (ENABLE_CPU)
    file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
    if (NOT ENABLE_MPI)
        list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
    endif ()
    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_interface.cc")
endif ()

if (ENABLE_MPI)
    # _ms_mpi
    set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
    file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc")
    set_property(SOURCE ${MPI_SRC_LIST}
                 PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
    add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
    target_link_libraries(mpi_adapter PRIVATE mindspore::ompi)

    set_property(SOURCE "cpu/mpi/mpi_interface.cc"
                 PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
    pybind11_add_module(_ms_mpi "cpu/mpi/mpi_interface.cc")
    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mpi_adapter)
endif ()

# gpu

@@ -15,13 +15,41 @@
 */

#include "device/cpu/mpi/mpi_adapter.h"
#ifdef ENABLE_MPI
#include <algorithm>
#include "utils/mpi/mpi_config.h"
#include <sstream>
#include "pybind11/pybind11.h"
#endif // ENABLE_MPI
#include "utils/log_adapter.h"

namespace mindspore {
namespace device {
namespace cpu {
std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
  if (instance_ == nullptr) {
    MS_LOG(DEBUG) << "Create new mpi adapter instance.";
    instance_.reset(new (std::nothrow) MPIAdapter());
  }
  return instance_;
}

#ifdef ENABLE_MPI

#define RAISE_EXCEPTION(message) \
  { \
    std::ostringstream oss; \
    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \
    pybind11::pybind11_fail(oss.str()); \
  }

#define RAISE_EXCEPTION_WITH_PARAM(message, param) \
  { \
    std::ostringstream oss; \
    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \
    pybind11::pybind11_fail(oss.str()); \
  }

namespace {
MPI_Op GetMpiOp(const std::string &op_type) {
  if (op_type == "sum") {

@@ -33,7 +61,8 @@ MPI_Op GetMpiOp(const std::string &op_type) {
  } else if (op_type == "prod") {
    return MPI_PROD;
  }
  MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;

  RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type);
  return MPI_SUM;
}

@@ -46,80 +75,72 @@ int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
    }
  }
  if (scatter_index == -1) {
    MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
    RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid);
  }
  return scatter_index;
}
}  // namespace

MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); }

MPIAdapter::~MPIAdapter() {
  int finalized;
  MPI_Finalized(&finalized);
  if (finalized != 0) {
    return;
  }

  for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
    MPI_Group_free(&iter->second);
  }
  ranks_group_.clear();
  if (comm_group_world_ != MPI_GROUP_NULL) {
    MPI_Group_free(&comm_group_world_);
    comm_group_world_ = MPI_GROUP_NULL;
  }
  int finalized;
  MPI_Finalized(&finalized);
  if (finalized == 0) {
    MPI_Finalize();
  }
  MPI_Finalize();
}

MPIAdapter &MPIAdapter::Instance() {
  static MPIAdapter instance;
  return instance;
}

int MPIAdapter::GetRankId() const { return rank_id_; }

void MPIAdapter::Init() {
  static bool init = false;
  if (init) {
    return;
  }
  auto mpi_config_ptr = MpiConfig::GetInstance();
  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
  if (!mpi_config_ptr->enable_mpi()) {
    MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
  }

  int init_flag = 0;
  if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
    MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
    RAISE_EXCEPTION("Check mpi initialized fail!");
  }
  if (init_flag == 0) {
    auto ret = MPI_Init(nullptr, nullptr);
    if (ret != MPI_SUCCESS) {
      MS_LOG(EXCEPTION) << "Failed to init mpi!";
      RAISE_EXCEPTION("Failed to init mpi!");
    }
  }

  MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
  if (comm_group_world_ == MPI_GROUP_NULL) {
    MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
    RAISE_EXCEPTION("comm_group_world_ init fail!");
  }
  auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
  if (ret != MPI_SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
    RAISE_EXCEPTION("Failed to init mpi rank id!");
  }

  ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
  if (ret != MPI_SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_)
  }
  init = true;
}

MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
  if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
    MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
    RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size());
  }

  if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
    MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
    return MPI_GROUP_NULL;
    RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_);
  }
  std::lock_guard<std::mutex> lock(group_mutex_);
  auto iter = ranks_group_.find(ranks);

@@ -135,29 +156,28 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
  MPI_Group group = MPI_GROUP_NULL;
  MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
  if (group == MPI_GROUP_NULL) {
    MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_)
  }

  ranks_group_[ranks] = group;
  MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << group;
  return group;
}

bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                               const std::string &op_type) {
  if (ranks_group.empty()) {
    MS_LOG(ERROR) << "input rank group is empty!";
    RAISE_EXCEPTION("input rank group is empty!");
    return false;
  }

  auto group = AddGroup(ranks_group);
  if (group == MPI_GROUP_NULL) {
    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_)
  }
  MPI_Comm comm;
  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
  if (comm == MPI_COMM_NULL) {
    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
  }
  std::vector<int> receive_count(ranks_group.size(), 0);
  for (size_t i = 0; i < ranks_group.size(); ++i) {

@@ -168,13 +188,13 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
  auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
  bool result = true;
  if (ret != MPI_SUCCESS) {
    MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret);
    result = false;
  }

  ret = MPI_Comm_free(&comm);
  if (ret != MPI_SUCCESS) {
    MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret);
  }
  return result;
}

@@ -184,19 +204,18 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
  int scatter_index = GetScatterIndex(rank_id_, ranks_group);
  auto group = AddGroup(ranks_group);
  if (group == MPI_GROUP_NULL) {
    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
  }
  MPI_Comm comm;
  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
  if (comm == MPI_COMM_NULL) {
    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
  }

  MPI_Win window;
  auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
  if (ret != MPI_SUCCESS) {
    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
    return false;
    RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret);
  }
  MPI_Win_fence(0, window);
  for (size_t i = 0; i < ranks_group.size(); ++i) {

@@ -208,18 +227,20 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
    ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
                         input_data_num, MPI_FLOAT, op, window);
    if (ret != MPI_SUCCESS) {
      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
      RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret);
    }
  }
  MPI_Win_fence(0, window);
  if (output != nullptr) {
    auto data_size = input_data_num * sizeof(float);
    if (output_size < data_size) {
      MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
      std::ostringstream exception_msg;
      exception_msg << "output buffer size " << output_size << " < input size " << data_size;
      RAISE_EXCEPTION(exception_msg.str())
    }
    auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
    if (copy_ret != 0) {
      MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
      RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret);
    }
  }
  MPI_Win_free(&window);

@@ -229,31 +250,31 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int

bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
  if (ranks_group.empty()) {
    MS_LOG(ERROR) << "input rank group is empty!";
    RAISE_EXCEPTION("input rank group is empty!");
    return false;
  }
  auto group = AddGroup(ranks_group);
  if (group == MPI_GROUP_NULL) {
    MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_);
  }
  MPI_Comm comm;
  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
  if (comm == MPI_COMM_NULL) {
    MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_);
  }

  auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm);
  bool result = true;

  if (ret != MPI_SUCCESS) {
    MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
    result = false;
    RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret);
  }
  ret = MPI_Comm_free(&comm);
  if (ret != MPI_SUCCESS) {
    MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret);
  }
  return result;
  return true;
}
#endif // ENABLE_MPI
} // namespace cpu
} // namespace device
} // namespace mindspore
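
For context on the error handling above: pybind11::pybind11_fail() throws a C++ std::runtime_error, which pybind11 translates into a Python RuntimeError once the call crosses the binding layer, so failures reported through the RAISE_EXCEPTION macros surface as ordinary Python exceptions rather than aborting the process. A rough illustrative sketch of what a caller would see; the import path and the failure trigger are assumptions, not part of this commit:

# Illustrative sketch only; assumes the built module is importable as
# mindspore._ms_mpi and that some MPI call fails at runtime.
import mindspore._ms_mpi as ms_mpi

try:
    rank = ms_mpi.get_rank_id()
except RuntimeError as err:
    # the message carries the "[file] [line] detail" text built by the macro
    print("MPI adapter error:", err)
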
@@ -22,37 +22,53 @@
#include <map>
#include <string>
#include <mutex>
#endif // ENABLE_MPI
#include <memory>

namespace mindspore {
namespace device {
namespace cpu {
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif

constexpr auto kOpTypeSum = "sum";
class MPIAdapter {
 public:
  ~MPIAdapter();
  static MPIAdapter &Instance();
  int GetRankId() const;
  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                     const std::string &op_type = kOpTypeSum);
  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
                                   size_t output_size, const std::string &op_type = kOpTypeSum,
                                   float *output = nullptr);
  bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
  FUNC_EXPORT static std::shared_ptr<MPIAdapter> Instance();
  FUNC_EXPORT int GetRankId() const { return rank_id_; }
  FUNC_EXPORT int GetRankSize() const { return rank_size_; }
#ifdef ENABLE_MPI
  FUNC_EXPORT ~MPIAdapter();
  FUNC_EXPORT bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
                                 size_t data_num, const std::string &op_type = kOpTypeSum);
  FUNC_EXPORT bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
                                               size_t output_size, const std::string &op_type = kOpTypeSum,
                                               float *output = nullptr);
  FUNC_EXPORT bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
#else
  FUNC_EXPORT ~MPIAdapter() = default;
#endif // ENABLE_MPI

 private:
#ifdef ENABLE_MPI
  MPIAdapter();
  void Init();
  MPI_Group AddGroup(const std::vector<int> &ranks);

  int rank_id_;
  int rank_size_;
  MPI_Group comm_group_world_;
  // key:ranks group, value: mpi group
  std::map<std::vector<int>, MPI_Group> ranks_group_;
  std::mutex group_mutex_;
#else
  MPIAdapter() = default;
#endif // ENABLE_MPI
  int rank_id_{-1};
  int rank_size_{0};

  static std::shared_ptr<MPIAdapter> instance_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // ENABLE_MPI
#endif // MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_

@@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <pybind11/operators.h>
#include "device/cpu/mpi/mpi_adapter.h"

namespace mindspore {
namespace device {
namespace cpu {
int get_rank_id() { return MPIAdapter::Instance()->GetRankId(); }

int get_rank_size() { return MPIAdapter::Instance()->GetRankSize(); }

PYBIND11_MODULE(_ms_mpi, mpi_interface) {
  mpi_interface.doc() = "mindspore mpi python wrapper";
  mpi_interface.def("get_rank_id", &get_rank_id, "get rank id");
  mpi_interface.def("get_rank_size", &get_rank_size, "get rank size");
}
} // namespace cpu
} // namespace device
} // namespace mindspore
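
The new cpu/mpi/mpi_interface.cc above is what backs the _ms_mpi Python module after this change. A minimal usage sketch; the import path and the mpirun launch are assumptions, not shown in this commit:

# Hypothetical usage sketch: run with e.g. `mpirun -n 4 python this_script.py`.
from mindspore._ms_mpi import get_rank_id, get_rank_size  # import path assumed

rank = get_rank_id()    # rank of this process in MPI_COMM_WORLD
size = get_rank_size()  # total number of MPI processes
print("rank %d of %d" % (rank, size))
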
@@ -17,7 +17,6 @@
#include "device/gpu/mpi/mpi_initializer.h"

#include <mpi.h>
#include <pybind11/operators.h>
#include <iostream>

namespace mindspore {

@@ -54,12 +53,6 @@ MPIInitializer &MPIInitializer::GetInstance() {
int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }

int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }

PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
  mpi_initializer.doc() = "mindspore mpi python wrapper";
  mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id");
  mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size");
}
} // namespace gpu
} // namespace device
} // namespace mindspore

@@ -47,7 +47,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
  auto input_data_num = inputs[0]->size / sizeof(float);

  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num);
  return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
}
} // namespace kernel
} // namespace mindspore

@@ -51,8 +51,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
  size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
  size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
  for (int i = 0; i < split_num_; i++) {
    device::cpu::MPIAdapter::Instance().AllGather(input_addr + i * input_split_lens,
                                                  output_addr + i * output_split_lens, rank_group, input_split_lens);
    device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
                                                   output_addr + i * output_split_lens, rank_group, input_split_lens);
  }
#if defined(_WIN32) || defined(_WIN64)
  auto end_time = std::chrono::steady_clock::now();

@@ -105,9 +105,9 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
    size_t reduce_scatter_out_lens = one_split_lens / 8;
    const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
    for (int i = 0; i < split_num_; i++) {
      device::cpu::MPIAdapter::Instance().ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
                                                        output_addr + i * reduce_scatter_out_lens, group,
                                                        one_split_lens / 8, "sum");
      device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
                                                         output_addr + i * reduce_scatter_out_lens, group,
                                                         one_split_lens / 8, "sum");
    }
  }
#endif

@@ -47,8 +47,8 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
  auto output_data_num = outputs[0]->size / sizeof(float);

  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
                                                           op_type_);
  return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
                                                            op_type_);
}
} // namespace kernel
} // namespace mindspore

@@ -25,7 +25,6 @@ from mindspore._c_expression import MSContext
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
    _reset_auto_parallel_context
from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
           'get_auto_parallel_context', 'reset_auto_parallel_context']

@@ -608,40 +607,3 @@ def get_context(attr_key):
        raise ValueError(
            "Get context keyword %s is not recognized!" % attr_key)
    return getattr(_context(), attr_key)

@args_type_check(enable_mpi=bool)
def set_mpi_config(**kwargs):
    """
    Sets mpi config for running environment.

    mpi config should be configured before running your program. If there is no configuration,
    mpi moudle will be disabled by default.

    Note:
        Attribute name is required for setting attributes.

    Args:
        enable_mpi (bool): Whether to enable mpi. Default: False.

    Raises:
        ValueError: If input key is not an attribute in mpi config.

    Examples:
        >>> mpiconfig.set_mpi_config(enable_mpi=True)
    """
    _set_mpi_config(**kwargs)

def get_mpi_config(attr_key):
    """
    Gets mpi config attribute value according to the input key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Object, The value of given attribute key.

    Raises:
        ValueError: If input key is not an attribute in context.
    """
    return _get_mpi_config(attr_key)

@@ -104,7 +104,7 @@ def _get_mpi_config(attr_key):
        Object, The value of given attribute key.

    Raises:
        ValueError: If input key is not an attribute in context.
        ValueError: If input key is not an attribute in config.
    """
    if not hasattr(_mpi_config(), attr_key):
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
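
With set_mpi_config/get_mpi_config dropped from mindspore.context in this commit, the MPI flag itself still lives in mindspore.parallel.mpi._mpi_config (whose docstring is touched above); whatever public wrapper replaces the context API is not part of this diff. A hedged sketch using only the private names that appear in the diff:

# Sketch based on the private helpers referenced above; the public entry
# point after this refactor is an assumption and is not shown in this diff.
from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

_set_mpi_config(enable_mpi=True)
print(_get_mpi_config("enable_mpi"))  # expected: True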