forked from OSSInnovation/mindspore
!4170 dlopen cpu mpi adapter
Merge pull request !4170 from kisnwang/disable-host-mpi
This commit is contained in:
commit
234a3fd431
|
@ -140,15 +140,6 @@ if (ENABLE_MPI)
|
|||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
file(GLOB_RECURSE MPI_LIB_LIST
|
||||
${ompi_LIBPATH}/libmpi${CMAKE_SHARED_LIBRARY_SUFFIX}*
|
||||
${ompi_LIBPATH}/libopen*${CMAKE_SHARED_LIBRARY_SUFFIX}*
|
||||
)
|
||||
install(
|
||||
FILES ${MPI_LIB_LIST}
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (ENABLE_GPU)
|
||||
|
|
|
@ -155,11 +155,7 @@ if (ENABLE_DEBUGGER)
|
|||
endif()
|
||||
|
||||
target_link_libraries(mindspore proto_input)
|
||||
if (ENABLE_MPI AND ENABLE_CPU)
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
|
||||
else ()
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers)
|
||||
endif ()
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers)
|
||||
|
||||
if (NOT WIN32)
|
||||
target_link_libraries(mindspore dl)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/allgather_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -45,9 +45,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto input_data_num = inputs[0]->size / sizeof(float);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
return MPIAllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include <thread>
|
||||
#include "backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -49,11 +49,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
|
|||
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
||||
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
for (int i = 0; i < split_num_; i++) {
|
||||
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
|
||||
input_split_lens);
|
||||
MPIAllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, input_split_lens);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -24,7 +24,7 @@ namespace {
|
|||
constexpr auto kRanksGroup = "group";
|
||||
} // namespace
|
||||
|
||||
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {}
|
||||
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(kMPIOpTypeSum) {}
|
||||
|
||||
void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
|
||||
|
@ -46,9 +46,7 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto output_data_num = outputs[0]->size / sizeof(float);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
|
||||
return MPIReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,12 +14,12 @@ endif ()
|
|||
|
||||
if (ENABLE_CPU)
    # Collect every CPU runtime source, relative to this directory.
    file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
    # The MPI adapter sources are excluded here; they are compiled separately into
    # the mpi_adapter shared library when ENABLE_MPI is on.
    # CMake arguments are separated by whitespace: a comma between the quoted paths
    # would be passed as an extra "," argument and raise a syntax warning.
    list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc" "cpu/mpi/mpi_export.cc")
endif ()
|
||||
|
||||
if (ENABLE_MPI)
|
||||
if (ENABLE_CPU)
|
||||
file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc")
|
||||
file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc", "cpu/mpi/mpi_export.cc")
|
||||
set_property(SOURCE ${MPI_SRC_LIST}
|
||||
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
|
||||
add_library(mpi_adapter SHARED ${MPI_SRC_LIST})
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <exception>
|
||||
#include <algorithm>
|
||||
#include "runtime/device/ascend/ascend_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "utils/mpi/mpi_config.h"
|
||||
|
@ -64,9 +64,7 @@ std::string GetRankId() {
|
|||
auto mpi_config_ptr = MpiConfig::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
|
||||
if (mpi_config_ptr->enable_mpi()) {
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
int rank_id = mpi_instance->GetRankId();
|
||||
int rank_id = GetMPIRankId();
|
||||
const char *offset = std::getenv("RANK_OFFSET");
|
||||
if (offset != nullptr) {
|
||||
try {
|
||||
|
|
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#ifdef ENABLE_MPI
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "pybind11/pybind11.h"
|
||||
#endif // ENABLE_MPI
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -33,8 +33,6 @@ std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
|
|||
return instance_;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MPI
|
||||
|
||||
#define RAISE_EXCEPTION(message) \
|
||||
{ \
|
||||
std::ostringstream oss; \
|
||||
|
@ -271,7 +269,6 @@ bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<
|
|||
}
|
||||
return true;
|
||||
}
|
||||
#endif // ENABLE_MPI
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,13 +16,11 @@
|
|||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_ADAPTER_H_
|
||||
#ifdef ENABLE_MPI
|
||||
#include <mpi.h>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#endif // ENABLE_MPI
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -31,27 +29,19 @@ namespace cpu {
|
|||
#ifndef FUNC_EXPORT
|
||||
#define FUNC_EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
constexpr auto kOpTypeSum = "sum";
|
||||
class MPIAdapter {
|
||||
public:
|
||||
FUNC_EXPORT static std::shared_ptr<MPIAdapter> Instance();
|
||||
FUNC_EXPORT int GetRankId() const { return rank_id_; }
|
||||
FUNC_EXPORT int GetRankSize() const { return rank_size_; }
|
||||
#ifdef ENABLE_MPI
|
||||
FUNC_EXPORT ~MPIAdapter();
|
||||
FUNC_EXPORT bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
|
||||
size_t data_num, const std::string &op_type = kOpTypeSum);
|
||||
size_t data_num, const std::string &op_type);
|
||||
FUNC_EXPORT bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
|
||||
size_t output_size, const std::string &op_type = kOpTypeSum,
|
||||
float *output = nullptr);
|
||||
size_t output_size, const std::string &op_type, float *output);
|
||||
FUNC_EXPORT bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
|
||||
#else
|
||||
FUNC_EXPORT ~MPIAdapter() = default;
|
||||
#endif // ENABLE_MPI
|
||||
|
||||
private:
|
||||
#ifdef ENABLE_MPI
|
||||
MPIAdapter();
|
||||
void Init();
|
||||
MPI_Group AddGroup(const std::vector<int> &ranks);
|
||||
|
@ -60,9 +50,6 @@ class MPIAdapter {
|
|||
// key:ranks group, value: mpi group
|
||||
std::map<std::vector<int>, MPI_Group> ranks_group_;
|
||||
std::mutex group_mutex_;
|
||||
#else
|
||||
MPIAdapter() = default;
|
||||
#endif // ENABLE_MPI
|
||||
int rank_id_{-1};
|
||||
int rank_size_{0};
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/device/cpu/mpi/mpi_export.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
|
||||
int GetMPIRankId() {
|
||||
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
|
||||
if (inst == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return inst->GetRankId();
|
||||
}
|
||||
|
||||
int GetMPIRankSize() {
|
||||
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
|
||||
if (inst == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return inst->GetRankSize();
|
||||
}
|
||||
|
||||
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
|
||||
const std::string &op_type) {
|
||||
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
|
||||
if (inst == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return inst->ReduceScatter(input, output, ranks_group, data_num, op_type);
|
||||
}
|
||||
|
||||
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
|
||||
size_t output_size, const std::string &op_type, float *output) {
|
||||
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
|
||||
if (inst == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return inst->ReduceScatterOverwriteInput(input, ranks_group, in_data_num, output_size, op_type, output);
|
||||
}
|
||||
|
||||
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
|
||||
auto inst = mindspore::device::cpu::MPIAdapter::Instance();
|
||||
if (inst == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return inst->AllGather(input, output, ranks_group, data_num);
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
// Marks a symbol as visible outside the shared library.
#ifndef FUNC_EXPORT
#define FUNC_EXPORT __attribute__((visibility("default")))
#endif

// Entry points of the MPI adapter library. They are declared with C linkage so
// the symbol names stay unmangled and can be looked up by name with dlsym.
// Each one is implemented in mpi_export.cc by forwarding to the MPIAdapter singleton.

// Rank id of the current process (0 when no adapter instance exists).
extern "C" FUNC_EXPORT int GetMPIRankId();
// Rank size reported by the adapter (0 when no adapter instance exists).
extern "C" FUNC_EXPORT int GetMPIRankSize();
// Reduce-scatter over `ranks_group`; returns false when no adapter instance exists.
extern "C" FUNC_EXPORT bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group,
                                             size_t data_num, const std::string &op_type);
// Reduce-scatter variant forwarding to MPIAdapter::ReduceScatterOverwriteInput;
// returns false when no adapter instance exists.
extern "C" FUNC_EXPORT bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group,
                                                           size_t in_data_num, size_t output_size,
                                                           const std::string &op_type, float *output);
// All-gather over `ranks_group`; returns false when no adapter instance exists.
extern "C" FUNC_EXPORT bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group,
                                         size_t data_num);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_EXPORT_H_
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
#ifdef ENABLE_MPI
|
||||
#include <dlfcn.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
inline void *LoadLibrary(const char *name) {
|
||||
auto handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Load lib " << name << " failed, make sure you have installed it!";
|
||||
}
|
||||
return handle;
|
||||
}
|
||||
|
||||
inline void *GetMPIAdapterHandle() {
|
||||
static void *handle = LoadLibrary("mpi_adapter.so");
|
||||
return handle;
|
||||
}
|
||||
|
||||
inline void *GetMPIAdapterFunc(const char *name) {
|
||||
static void *handle = GetMPIAdapterHandle();
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Load lib " << name << " failed, make sure you have installed it!";
|
||||
}
|
||||
void *func = dlsym(handle, name);
|
||||
if (func == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Load func " << name << " failed, make sure you have implied it!";
|
||||
}
|
||||
return func;
|
||||
}
|
||||
|
||||
// Function-pointer types for the symbols resolved from the MPI adapter library.
// Each alias mirrors the signature of the exported function of the same name.
using GetMPIRankIdFunc = int (*)();
using GetMPIRankSizeFunc = int (*)();
using MPIReduceScatterFunc = bool (*)(const float *input, float *output, const std::vector<int> &ranks_group,
                                      size_t data_num, const std::string &op_type);
using MPIReduceScatterOverwriteInputFunc = bool (*)(float *input, const std::vector<int> &ranks_group,
                                                    size_t in_data_num, size_t output_size,
                                                    const std::string &op_type, float *output);
using MPIAllGatherFunc = bool (*)(const float *input, float *output, const std::vector<int> &ranks_group,
                                  size_t data_num);
|
||||
|
||||
int GetMPIRankId() {
|
||||
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankIdFunc>(GetMPIAdapterFunc("GetMPIRankId"));
|
||||
return func();
|
||||
}
|
||||
|
||||
int GetMPIRankSize() {
|
||||
static GetMPIRankIdFunc func = reinterpret_cast<GetMPIRankSizeFunc>(GetMPIAdapterFunc("GetMPIRankSize"));
|
||||
return func();
|
||||
}
|
||||
|
||||
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
|
||||
const std::string &op_type) {
|
||||
static MPIReduceScatterFunc func = reinterpret_cast<MPIReduceScatterFunc>(GetMPIAdapterFunc("MPIReduceScatter"));
|
||||
return func(input, output, ranks_group, data_num, op_type);
|
||||
}
|
||||
|
||||
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
|
||||
size_t output_size, const std::string &op_type, float *output) {
|
||||
static MPIReduceScatterOverwriteInputFunc func =
|
||||
reinterpret_cast<MPIReduceScatterOverwriteInputFunc>(GetMPIAdapterFunc("MPIReduceScatterOverwriteInput"));
|
||||
return func(input, ranks_group, in_data_num, output_size, op_type, output);
|
||||
}
|
||||
|
||||
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
|
||||
static MPIAllGatherFunc func = reinterpret_cast<MPIAllGatherFunc>(GetMPIAdapterFunc("MPIAllGather"));
|
||||
return func(input, output, ranks_group, data_num);
|
||||
}
|
||||
#endif // ENABLE_MPI
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#ifndef FUNC_EXPORT
|
||||
#define FUNC_EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
// Name of the reduction used as the default `op_type` by the wrappers declared below.
constexpr const char *kMPIOpTypeSum = "sum";
#ifdef ENABLE_MPI
// Wrappers around the MPI adapter library; only declared when ENABLE_MPI is defined.
// Their definitions in mpi_interface.cc resolve the same-named symbols with dlsym.

// Rank id as reported by the adapter library.
int GetMPIRankId();
// Rank size as reported by the adapter library.
int GetMPIRankSize();
// Reduce-scatter over `ranks_group`; `op_type` defaults to kMPIOpTypeSum.
bool MPIReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kMPIOpTypeSum);
// Reduce-scatter variant with an optional separate `output` buffer (defaults to nullptr);
// `op_type` defaults to kMPIOpTypeSum.
bool MPIReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t in_data_num,
                                    size_t output_size, const std::string &op_type = kMPIOpTypeSum,
                                    float *output = nullptr);
// All-gather over `ranks_group`.
bool MPIAllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
#endif  // ENABLE_MPI
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_CPU_MPI_MPI_INTERFACE_H_
|
|
@ -19,6 +19,7 @@
|
|||
#include <mpi.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
|
|
Loading…
Reference in New Issue