building _ms_mpi with mpi_interface
This commit is contained in:
parent c9b8a8da0a
commit 343889cdb7
@@ -128,6 +128,11 @@ if (ENABLE_MPI)
         DESTINATION ${INSTALL_BASE_DIR}
         COMPONENT mindspore
     )
+    install(
+        TARGETS mpi_adapter
+        DESTINATION ${INSTALL_LIB_DIR}
+        COMPONENT mindspore
+    )
 endif ()

 if (ENABLE_GPU)
@@ -126,11 +126,12 @@ endforeach ()
 set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
 add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
 target_link_libraries(mindspore proto_input)
-if (ENABLE_CPU AND ENABLE_MPI)
-    target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
+if (ENABLE_MPI)
+    target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
 else ()
     target_link_libraries(mindspore securec mindspore::flatbuffers)
 endif ()

 if (NOT WIN32)
     target_link_libraries(mindspore dl)
 endif()
@@ -14,17 +14,22 @@ endif ()

 if (ENABLE_CPU)
     file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
-    if (NOT ENABLE_MPI)
-        list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
-    endif ()
+    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
+    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_interface.cc")
 endif ()

 if (ENABLE_MPI)
     # _ms_mpi
-    set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
+    file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc")
+    set_property(SOURCE ${MPI_SRC_LIST}
                  PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
-    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
-    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
+    add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
+    target_link_libraries(mpi_adapter PRIVATE mindspore::ompi)
+
+    set_property(SOURCE "cpu/mpi/mpi_interface.cc"
+                 PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+    pybind11_add_module(_ms_mpi "cpu/mpi/mpi_interface.cc")
+    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mpi_adapter)
 endif ()

 # gpu
@@ -15,13 +15,41 @@
  */

 #include "device/cpu/mpi/mpi_adapter.h"
+#ifdef ENABLE_MPI
 #include <algorithm>
-#include "utils/mpi/mpi_config.h"
+#include <sstream>
+#include "pybind11/pybind11.h"
+#endif  // ENABLE_MPI
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace device {
 namespace cpu {
+std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
+std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
+  if (instance_ == nullptr) {
+    MS_LOG(DEBUG) << "Create new mpi adapter instance.";
+    instance_.reset(new (std::nothrow) MPIAdapter());
+  }
+  return instance_;
+}
+
+#ifdef ENABLE_MPI
+
+#define RAISE_EXCEPTION(message)                                     \
+  {                                                                  \
+    std::ostringstream oss;                                          \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message;  \
+    pybind11::pybind11_fail(oss.str());                              \
+  }
+
+#define RAISE_EXCEPTION_WITH_PARAM(message, param)                            \
+  {                                                                           \
+    std::ostringstream oss;                                                   \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param;  \
+    pybind11::pybind11_fail(oss.str());                                       \
+  }
+
 namespace {
 MPI_Op GetMpiOp(const std::string &op_type) {
   if (op_type == "sum") {
@@ -33,7 +61,8 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   } else if (op_type == "prod") {
     return MPI_PROD;
   }
-  MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
+  RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type);
   return MPI_SUM;
 }
@@ -46,80 +75,72 @@ int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
     }
   }
   if (scatter_index == -1) {
-    MS_LOG(EXCEPTION) << "process rankid " << rankid << " does not in the input rank group!";
+    RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid);
   }
   return scatter_index;
 }
 }  // namespace

-MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
+MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); }

 MPIAdapter::~MPIAdapter() {
+  int finalized;
+  MPI_Finalized(&finalized);
+  if (finalized != 0) {
+    return;
+  }
+
   for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
     MPI_Group_free(&iter->second);
   }
+  ranks_group_.clear();
   if (comm_group_world_ != MPI_GROUP_NULL) {
     MPI_Group_free(&comm_group_world_);
+    comm_group_world_ = MPI_GROUP_NULL;
   }
-  int finalized;
-  MPI_Finalized(&finalized);
-  if (finalized == 0) {
-    MPI_Finalize();
-  }
+  MPI_Finalize();
 }

-MPIAdapter &MPIAdapter::Instance() {
-  static MPIAdapter instance;
-  return instance;
-}
-
-int MPIAdapter::GetRankId() const { return rank_id_; }
-
 void MPIAdapter::Init() {
   static bool init = false;
   if (init) {
     return;
   }
-  auto mpi_config_ptr = MpiConfig::GetInstance();
-  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
-  if (!mpi_config_ptr->enable_mpi()) {
-    MS_LOG(EXCEPTION) << "MPI is disabled now!Please enable mpi with mpi config first.";
-  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
+    RAISE_EXCEPTION("Check mpi initialized fail!");
   }
   if (init_flag == 0) {
     auto ret = MPI_Init(nullptr, nullptr);
     if (ret != MPI_SUCCESS) {
-      MS_LOG(EXCEPTION) << "Failed to init mpi!";
+      RAISE_EXCEPTION("Failed to init mpi!");
     }
   }

   MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
   if (comm_group_world_ == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
+    RAISE_EXCEPTION("comm_group_world_ init fail!");
   }
   auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
+    RAISE_EXCEPTION("Failed to init mpi rank id!");
   }

   ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_)
   }
   init = true;
 }

 MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
-    MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
+    RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size());
   }

   if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
-    MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
-    return MPI_GROUP_NULL;
+    RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_);
   }
   std::lock_guard<std::mutex> lock(group_mutex_);
   auto iter = ranks_group_.find(ranks);
@@ -135,29 +156,28 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   MPI_Group group = MPI_GROUP_NULL;
   MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_)
   }

   ranks_group_[ranks] = group;
-  MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << group;
   return group;
 }

 bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                                const std::string &op_type) {
   if (ranks_group.empty()) {
-    MS_LOG(ERROR) << "input rank group is empty!";
+    RAISE_EXCEPTION("input rank group is empty!");
     return false;
   }

   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_)
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
   }
   std::vector<int> receive_count(ranks_group.size(), 0);
   for (size_t i = 0; i < ranks_group.size(); ++i) {
@@ -168,13 +188,13 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
   auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
   bool result = true;
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret);
     result = false;
   }

   ret = MPI_Comm_free(&comm);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret);
   }
   return result;
 }
@@ -184,19 +204,18 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   int scatter_index = GetScatterIndex(rank_id_, ranks_group);
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_);
   }

   MPI_Win window;
   auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
-    return false;
+    RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret);
   }
   MPI_Win_fence(0, window);
   for (size_t i = 0; i < ranks_group.size(); ++i) {
@@ -208,18 +227,20 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
     ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
                          input_data_num, MPI_FLOAT, op, window);
     if (ret != MPI_SUCCESS) {
-      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+      RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret);
     }
   }
   MPI_Win_fence(0, window);
   if (output != nullptr) {
     auto data_size = input_data_num * sizeof(float);
     if (output_size < data_size) {
-      MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
+      std::ostringstream exception_msg;
+      exception_msg << "output buffer size " << output_size << " < input size " << data_size;
+      RAISE_EXCEPTION(exception_msg.str())
     }
     auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
     if (copy_ret != 0) {
-      MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
+      RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret);
     }
   }
   MPI_Win_free(&window);
@@ -229,31 +250,31 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int

 bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
-    MS_LOG(ERROR) << "input rank group is empty!";
+    RAISE_EXCEPTION("input rank group is empty!");
     return false;
   }
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
-    MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_);
   }
   MPI_Comm comm;
   MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
   if (comm == MPI_COMM_NULL) {
-    MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_);
   }

   auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm);
-  bool result = true;
   if (ret != MPI_SUCCESS) {
-    MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
-    result = false;
+    RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret);
   }
   ret = MPI_Comm_free(&comm);
   if (ret != MPI_SUCCESS) {
-    MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
+    RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret);
   }
-  return result;
+  return true;
 }
+#endif  // ENABLE_MPI
 }  // namespace cpu
 }  // namespace device
 }  // namespace mindspore
@@ -22,37 +22,53 @@
 #include <map>
 #include <string>
 #include <mutex>
+#endif  // ENABLE_MPI
+#include <memory>

 namespace mindspore {
 namespace device {
 namespace cpu {
+#ifndef FUNC_EXPORT
+#define FUNC_EXPORT __attribute__((visibility("default")))
+#endif
+
 constexpr auto kOpTypeSum = "sum";
 class MPIAdapter {
  public:
-  ~MPIAdapter();
-  static MPIAdapter &Instance();
-  int GetRankId() const;
-  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
-                     const std::string &op_type = kOpTypeSum);
-  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
-                                   size_t output_size, const std::string &op_type = kOpTypeSum,
-                                   float *output = nullptr);
-  bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
+  FUNC_EXPORT static std::shared_ptr<MPIAdapter> Instance();
+  FUNC_EXPORT int GetRankId() const { return rank_id_; }
+  FUNC_EXPORT int GetRankSize() const { return rank_size_; }
+#ifdef ENABLE_MPI
+  FUNC_EXPORT ~MPIAdapter();
+  FUNC_EXPORT bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
+                                 size_t data_num, const std::string &op_type = kOpTypeSum);
+  FUNC_EXPORT bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
+                                               size_t output_size, const std::string &op_type = kOpTypeSum,
+                                               float *output = nullptr);
+  FUNC_EXPORT bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
+#else
+  FUNC_EXPORT ~MPIAdapter() = default;
+#endif  // ENABLE_MPI

  private:
+#ifdef ENABLE_MPI
   MPIAdapter();
   void Init();
   MPI_Group AddGroup(const std::vector<int> &ranks);

-  int rank_id_;
-  int rank_size_;
   MPI_Group comm_group_world_;
   // key:ranks group, value: mpi group
   std::map<std::vector<int>, MPI_Group> ranks_group_;
   std::mutex group_mutex_;
+#else
+  MPIAdapter() = default;
+#endif  // ENABLE_MPI
+  int rank_id_{-1};
+  int rank_size_{0};
+
+  static std::shared_ptr<MPIAdapter> instance_;
 };
 }  // namespace cpu
 }  // namespace device
 }  // namespace mindspore
-#endif  // ENABLE_MPI
 #endif  // MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <pybind11/operators.h>
+#include "device/cpu/mpi/mpi_adapter.h"
+
+namespace mindspore {
+namespace device {
+namespace cpu {
+int get_rank_id() { return MPIAdapter::Instance()->GetRankId(); }
+
+int get_rank_size() { return MPIAdapter::Instance()->GetRankSize(); }
+
+PYBIND11_MODULE(_ms_mpi, mpi_interface) {
+  mpi_interface.doc() = "mindspore mpi python wrapper";
+  mpi_interface.def("get_rank_id", &get_rank_id, "get rank id");
+  mpi_interface.def("get_rank_size", &get_rank_size, "get rank size");
+}
+}  // namespace cpu
+}  // namespace device
+}  // namespace mindspore
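For orientation, a minimal usage sketch of the _ms_mpi bindings defined in the new file above. This is not part of the commit: it assumes the extension was built with ENABLE_MPI, that it is importable from the current environment (the exact import location depends on how the package is installed), and that the script is launched under mpirun.

    # Hypothetical smoke test of the pybind11 interface shown above.
    import _ms_mpi

    rank = _ms_mpi.get_rank_id()      # wraps MPIAdapter::Instance()->GetRankId()
    size = _ms_mpi.get_rank_size()    # wraps MPIAdapter::Instance()->GetRankSize()
    print("rank %d of %d" % (rank, size))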
@@ -17,7 +17,6 @@
 #include "device/gpu/mpi/mpi_initializer.h"

 #include <mpi.h>
-#include <pybind11/operators.h>
 #include <iostream>

 namespace mindspore {
@@ -54,12 +53,6 @@ MPIInitializer &MPIInitializer::GetInstance() {
 int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }

 int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
-
-PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
-  mpi_initializer.doc() = "mindspore mpi python wrapper";
-  mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id");
-  mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size");
-}
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
@@ -47,7 +47,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
   auto input_data_num = inputs[0]->size / sizeof(float);

-  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num);
+  return device::cpu::MPIAdapter::Instance()->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
 }
 }  // namespace kernel
 }  // namespace mindspore
@@ -51,8 +51,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
   size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
   size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
   for (int i = 0; i < split_num_; i++) {
-    device::cpu::MPIAdapter::Instance().AllGather(input_addr + i * input_split_lens,
+    device::cpu::MPIAdapter::Instance()->AllGather(input_addr + i * input_split_lens,
                                                   output_addr + i * output_split_lens, rank_group, input_split_lens);
   }
 #if defined(_WIN32) || defined(_WIN64)
   auto end_time = std::chrono::steady_clock::now();
@@ -105,9 +105,9 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
     size_t reduce_scatter_out_lens = one_split_lens / 8;
     const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
     for (int i = 0; i < split_num_; i++) {
-      device::cpu::MPIAdapter::Instance().ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
+      device::cpu::MPIAdapter::Instance()->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
                                                          output_addr + i * reduce_scatter_out_lens, group,
                                                          one_split_lens / 8, "sum");
     }
   }
 #endif
@@ -47,8 +47,8 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
   auto output_data_num = outputs[0]->size / sizeof(float);

-  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
+  return device::cpu::MPIAdapter::Instance()->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
                                                             op_type_);
 }
 }  // namespace kernel
 }  // namespace mindspore
@@ -25,7 +25,6 @@ from mindspore._c_expression import MSContext
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
-from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config

 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
            'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -608,40 +607,3 @@ def get_context(attr_key):
         raise ValueError(
             "Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
-
-@args_type_check(enable_mpi=bool)
-def set_mpi_config(**kwargs):
-    """
-    Sets mpi config for running environment.
-
-    mpi config should be configured before running your program. If there is no configuration,
-    mpi moudle will be disabled by default.
-
-    Note:
-        Attribute name is required for setting attributes.
-
-    Args:
-        enable_mpi (bool): Whether to enable mpi. Default: False.
-
-    Raises:
-        ValueError: If input key is not an attribute in mpi config.
-
-    Examples:
-        >>> mpiconfig.set_mpi_config(enable_mpi=True)
-    """
-    _set_mpi_config(**kwargs)
-
-def get_mpi_config(attr_key):
-    """
-    Gets mpi config attribute value according to the input key.
-
-    Args:
-        attr_key (str): The key of the attribute.
-
-    Returns:
-        Object, The value of given attribute key.
-
-    Raises:
-        ValueError: If input key is not an attribute in context.
-    """
-    return _get_mpi_config(attr_key)
@@ -104,7 +104,7 @@ def _get_mpi_config(attr_key):
         Object, The value of given attribute key.

     Raises:
-        ValueError: If input key is not an attribute in context.
+        ValueError: If input key is not an attribute in config.
     """
     if not hasattr(_mpi_config(), attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)