support host reduce scatter and allgather

This commit is contained in:
chenjianping 2020-06-04 17:20:47 +08:00
parent 5306172fee
commit af8108c9e1
17 changed files with 597 additions and 20 deletions

View File

@ -86,7 +86,7 @@ checkopts()
ENABLE_DUMPE2E="off"
ENABLE_DUMP_IR="on"
COMPILE_MINDDATA="on"
ENABLE_MPI="on"
ENABLE_MPI="off"
CUDA_VERSION="9.2"
COMPILE_PREDICT="off"
USE_GLOG="on"
@ -168,6 +168,7 @@ checkopts()
if [[ "X$OPTARG" == "Xgpu" ]]; then
ENABLE_GPU="on"
ENABLE_CPU="on"
ENABLE_MPI="on"
elif [[ "X$OPTARG" == "Xd" || "X$OPTARG" == "Xascend" ]]; then
ENABLE_D="on"
ENABLE_CPU="on"

View File

@ -26,6 +26,9 @@ include_directories(${Python3_INCLUDE_DIRS})
include_directories(${CMAKE_SOURCE_DIR}/third_party)
if (ENABLE_CPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake)
if (ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
endif()
endif()
if (ENABLE_GPU)
@ -36,7 +39,6 @@ if (ENABLE_GPU)
if (ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
endif()
endif()

View File

@ -109,19 +109,20 @@ if (ENABLE_CPU)
)
endif ()
if (ENABLE_GPU)
if (ENABLE_MPI)
install(
TARGETS _ms_mpi
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
endif ()
if (ENABLE_GPU)
install(
TARGETS gpu_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif ()
install(
TARGETS gpu_queue
DESTINATION ${INSTALL_LIB_DIR}

View File

@ -8,6 +8,10 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
add_compile_definitions(BUILDING_DLL)
endif()
if (ENABLE_MPI)
add_compile_definitions(ENABLE_MPI)
endif ()
if(ENABLE_GPU)
find_package(CUDA REQUIRED)
find_package(Threads)
@ -120,7 +124,11 @@ endforeach ()
set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
target_link_libraries(mindspore proto_input)
if (ENABLE_CPU AND ENABLE_MPI)
target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
else ()
target_link_libraries(mindspore securec mindspore::flatbuffers)
endif ()
if (NOT WIN32)
target_link_libraries(mindspore dl)
endif()

View File

@ -14,6 +14,15 @@ endif ()
if (ENABLE_CPU)
file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
if (ENABLE_MPI)
# _ms_mpi
set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
else ()
list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
endif ()
endif ()
# gpu
@ -39,11 +48,6 @@ if (ENABLE_GPU)
set_property(SOURCE ${GPU_COLLECTIVE_SRCS}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS})
# _ms_mpi
set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl)
endif ()

View File

@ -15,7 +15,6 @@
*/
#include "device/ascend/ascend_kernel_runtime.h"
#include <string>
#include <vector>
#include <memory>
@ -24,6 +23,7 @@
#include <algorithm>
#include "device/ascend/ascend_device_address.h"
#include "device/cpu/mpi/mpi_adapter.h"
#include "utils/context/ms_context.h"
#include "device/ascend/profiling/profiling_manager.h"
#include "hccl/hcom.h"
@ -510,11 +510,19 @@ bool AscendKernelRuntime::HcclInit() {
MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
return false;
}
#ifdef ENABLE_MPI
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
int rank_offset = std::stoi(offset);
rank_id += rank_offset;
}
const char *identify = reinterpret_cast<const char *>(std::to_string(rank_id).c_str());
#else
const char *identify = std::getenv("RANK_ID");
#endif
if (identify == nullptr) {
MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
free(full_path);
return false;
}
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;

View File

@ -0,0 +1,191 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "device/cpu/mpi/mpi_adapter.h"
#include <algorithm>
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
namespace cpu {
namespace {
MPI_Op GetMpiOp(const std::string &op_type) {
if (op_type == "sum") {
return MPI_SUM;
} else if (op_type == "max") {
return MPI_MAX;
} else if (op_type == "min") {
return MPI_MIN;
} else if (op_type == "prod") {
return MPI_PROD;
}
MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
return MPI_SUM;
}
} // namespace
MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
MPIAdapter::~MPIAdapter() {
for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
MPI_Group_free(&iter->second);
}
if (comm_group_world_ != MPI_GROUP_NULL) {
MPI_Group_free(&comm_group_world_);
}
int finalized;
MPI_Finalized(&finalized);
if (finalized == 0) {
MPI_Finalize();
}
}
MPIAdapter &MPIAdapter::Instance() {
static MPIAdapter instance;
return instance;
}
int MPIAdapter::GetRankId() const { return rank_id_; }
void MPIAdapter::Init() {
static bool init = false;
if (init) {
return;
}
int init_flag = 0;
if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
}
if (init_flag == 0) {
auto ret = MPI_Init(nullptr, nullptr);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Failed to init mpi!";
}
}
MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
if (comm_group_world_ == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "comm_group_world_ init fail!";
}
auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Failed to init mpi rank id!";
}
ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
if (ret != MPI_SUCCESS) {
MS_LOG(EXCEPTION) << "Failed to init mpi rank size!rankid:" << rank_id_;
}
init = true;
}
MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
MS_LOG(EXCEPTION) << "input rank size: " << ranks.size() << ", max rank size: " << rank_size_;
}
if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
MS_LOG(ERROR) << "rankid:" << rank_id_ << " is not in the input group.";
return MPI_GROUP_NULL;
}
std::lock_guard<std::mutex> lock(group_mutex_);
auto iter = ranks_group_.find(ranks);
if (iter != ranks_group_.end()) {
return iter->second;
}
const auto ranks_size = ranks.size();
std::vector<int> ranks_input(ranks_size, 0);
for (size_t i = 0; i < ranks_size; ++i) {
ranks_input[i] = ranks[i];
}
MPI_Group group = MPI_GROUP_NULL;
MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "create mpi group fail!rankid:" << rank_id_;
}
ranks_group_[ranks] = group;
MS_LOG(INFO) << "rank:" << rank_id_ << " add group:" << group;
return group;
}
bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type) {
if (ranks_group.empty()) {
MS_LOG(ERROR) << "input rank group is empty!";
return false;
}
auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
if (comm == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
}
std::vector<int> receive_count(ranks_group.size(), 0);
for (size_t i = 0; i < ranks_group.size(); ++i) {
receive_count[i] = data_num;
}
auto op = GetMpiOp(op_type);
auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm);
bool result = true;
if (ret != MPI_SUCCESS) {
MS_LOG(ERROR) << "mpi reduce_scatter fail!ret = " << ret << ", rankid:" << rank_id_;
result = false;
}
ret = MPI_Comm_free(&comm);
if (ret != MPI_SUCCESS) {
MS_LOG(WARNING) << "mpi comm free fail! ret = " << ret << ", rankid:" << rank_id_;
}
return result;
}
bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
if (ranks_group.empty()) {
MS_LOG(ERROR) << "input rank group is empty!";
return false;
}
auto group = AddGroup(ranks_group);
if (group == MPI_GROUP_NULL) {
MS_LOG(EXCEPTION) << "Get mpi group fail! rankid:" << rank_id_;
}
MPI_Comm comm;
MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
if (comm == MPI_COMM_NULL) {
MS_LOG(EXCEPTION) << "create mpi comm fail! rankid:" << rank_id_;
}
auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm);
bool result = true;
if (ret != MPI_SUCCESS) {
MS_LOG(ERROR) << "mpi allgater fail!ret = " << ret << ", rankid:" << rank_id_;
result = false;
}
ret = MPI_Comm_free(&comm);
if (ret != MPI_SUCCESS) {
MS_LOG(WARNING) << "mpi comm free fail!ret = " << ret << ",rankid:" << rank_id_;
}
return result;
}
} // namespace cpu
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_
#define MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_
#ifdef ENABLE_MPI
#include <mpi.h>
#include <vector>
#include <map>
#include <string>
#include <mutex>
namespace mindspore {
namespace device {
namespace cpu {
constexpr auto kOpTypeSum = "sum";
class MPIAdapter {
public:
~MPIAdapter();
static MPIAdapter &Instance();
int GetRankId() const;
bool ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
const std::string &op_type = kOpTypeSum);
bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
private:
MPIAdapter();
void Init();
MPI_Group AddGroup(const std::vector<int> &ranks);
int rank_id_;
int rank_size_;
MPI_Group comm_group_world_;
// key:ranks group, value: mpi group
std::map<std::vector<int>, MPI_Group> ranks_group_;
std::mutex group_mutex_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // ENABLE_MPI
#endif // MINDSPORE_CCSRC_DEVICE_CPU_MPI_MPI_ADAPTER_H_

View File

@ -20,7 +20,6 @@
namespace mindspore {
namespace device {
namespace gpu {
class CollectiveFakeInitializer {
public:
CollectiveFakeInitializer() = default;

View File

@ -24,10 +24,28 @@ namespace mindspore {
namespace device {
namespace gpu {
MPIInitializer::MPIInitializer() {
int init_flag = 0;
if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
return;
}
if (init_flag == 0) {
auto ret = MPI_Init(nullptr, nullptr);
if (ret != MPI_SUCCESS) {
return;
}
}
MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
}
MPIInitializer::~MPIInitializer() {
int finalized_flag = 0;
(void)MPI_Finalized(&finalized_flag);
if (finalized_flag == 0) {
(void)MPI_Finalize();
}
}
MPIInitializer &MPIInitializer::GetInstance() {
static MPIInitializer instance;
return instance;

View File

@ -30,7 +30,7 @@ class MPIInitializer {
private:
MPIInitializer();
~MPIInitializer() = default;
~MPIInitializer();
int rank_id_;
int rank_size_;

View File

@ -21,6 +21,11 @@ if (ENABLE_CPU)
file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"cpu/*.cc"
)
if (NOT ENABLE_MPI)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc")
endif ()
endif ()
if (ENABLE_GPU)

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/cpu/allgather_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "device/cpu/mpi/mpi_adapter.h"
#include "ir/primitive.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr auto kRanksGroup = "group";
constexpr auto kAllGatherInputNum = 1;
} // namespace
AllGatherCPUKernel::AllGatherCPUKernel() : input_data_number_(0) {}
void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kAllGatherInputNum) {
MS_LOG(EXCEPTION) << "allgather input num:" << input_num;
}
for (size_t i = 0; i < input_num; ++i) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
size_t count = 1;
for (size_t j = 0; j < shape.size(); j++) {
count *= IntToSize(shape[j]);
}
input_data_number_ += count;
}
auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup);
if (ranks_group != nullptr) {
ranks_group_ = GetValue<std::vector<int>>(ranks_group);
} else {
MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup;
}
}
bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_number_);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "kernel/cpu/cpu_kernel.h"
#include "kernel/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class AllGatherCPUKernel : public CPUKernel {
public:
AllGatherCPUKernel();
~AllGatherCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
size_t input_data_number_;
std::vector<int> ranks_group_;
};
MS_REG_CPU_KERNEL(HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AllGatherCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/cpu/reduce_scatter_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
#include "device/cpu/mpi/mpi_adapter.h"
#include "ir/primitive.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr auto kRanksGroup = "group";
} // namespace
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : output_data_number_(0), op_type_(device::cpu::kOpTypeSum) {}
void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t i = 0; i < output_num; ++i) {
auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
size_t size = 1;
for (size_t j = 0; j < shape.size(); j++) {
size *= IntToSize(shape[j]);
}
output_data_number_ += size;
}
auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
if (op != nullptr) {
op_type_ = GetValue<std::string>(op);
}
auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup);
if (ranks_group != nullptr) {
ranks_group_ = GetValue<std::vector<int>>(ranks_group);
} else {
MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup;
}
}
bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_number_,
op_type_);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
#include <vector>
#include <string>
#include "kernel/cpu/cpu_kernel.h"
#include "kernel/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class ReduceScatterCPUKernel : public CPUKernel {
public:
ReduceScatterCPUKernel();
~ReduceScatterCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
size_t output_data_number_;
std::string op_type_;
std::vector<int> ranks_group_;
};
MS_REG_CPU_KERNEL(HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReduceScatterCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_

View File

@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore._ms_mpi as mpi
# run comand:
# mpirun -np 3 python test_reduce_scatter.py
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = "sum"
self.reducescatter = P.HostReduceScatter(op=self.op, group=[0,1,2])
def construct(self, x):
return self.reducescatter(x)
class AllGatherNet(nn.Cell):
def __init__(self):
super(AllGatherNet, self).__init__()
self.hostallgather = P.HostAllGather(group=(0, 1, 2))
def construct(self, x):
return self.hostallgather(x)
def test_net_reduce_scatter():
x = np.ones(12).astype(np.float32) * 0.1
reducescatter = Net()
rankid = mpi.get_rank_id()
print("self rankid:", rankid)
output = reducescatter(Tensor(x, mstype.float32))
print("output:\n", output)
expect_result = np.ones(4).astype(np.float32) * 0.3
diff = abs(output.asnumpy() - expect_result)
error = np.ones(shape=expect_result.shape) * 1.0e-6
assert np.all(diff < error)
allgather = AllGatherNet()
allgather_output = allgather(output)
print("allgather result:\n", allgather_output)
expect_allgather_result = np.ones(12).astype(np.float32) * 0.3
diff = abs(allgather_output.asnumpy() - expect_allgather_result)
error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
assert np.all(diff < error)
if __name__ == '__main__':
test_net_reduce_scatter()