support host reduce

chenjianping 2020-06-18 14:35:59 +08:00
parent b106c2204a
commit 35900037af
6 changed files with 22 additions and 37 deletions
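Summary: ReduceScatterOverwriteInput now takes the input element count (input_data_num) plus an explicit output buffer size in bytes (output_size) and validates the destination before memcpy_s; AllGather takes a const input pointer; and both CPU kernels drop their cached element counters, deriving counts from the launch-time buffer sizes instead of from inferred shapes at init time.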

View File

@@ -179,8 +179,8 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
   return result;
 }
-bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
-                                             const std::string &op_type, float *output) {
+bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
+                                             size_t output_size, const std::string &op_type, float *output) {
   int scatter_index = GetScatterIndex(rank_id_, ranks_group);
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
@@ -193,7 +193,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   }
   MPI_Win window;
-  auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
+  auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
   if (ret != MPI_SUCCESS) {
     MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
     return false;
@@ -205,18 +205,21 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
       continue;
     }
     auto op = GetMpiOp(op_type);
-    ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
-                         window);
+    ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
+                         input_data_num, MPI_FLOAT, op, window);
     if (ret != MPI_SUCCESS) {
       MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
     }
   }
   MPI_Win_fence(0, window);
   if (output != nullptr) {
-    auto data_size = data_num * sizeof(float);
-    auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
+    auto data_size = input_data_num * sizeof(float);
+    if (output_size < data_size) {
+      MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
+    }
+    auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
     if (copy_ret != 0) {
-      MS_LOG(EXCEPTION) << "copy output memory fail!";
+      MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
     }
   }
   MPI_Win_free(&window);
@@ -224,7 +227,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   return true;
 }
-bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
+bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
     return false;

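For context, ReduceScatterOverwriteInput builds on MPI one-sided communication: the input buffer is exposed as an RMA window, each rank accumulates its chunk for peer i into rank i's window, and a fence completes the epoch before this rank's reduced chunk is read back. A minimal standalone sketch of that pattern (illustrative only, not code from this commit; it fixes the group to MPI_COMM_WORLD and the reduce op to MPI_SUM):

#include <mpi.h>
#include <cstdio>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank = 0;
  int size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  const int data_num = 4;  // floats each rank contributes per peer (assumed)
  std::vector<float> buf(data_num * size, static_cast<float>(rank + 1));

  // Expose the local buffer as an element-granular RMA window, like the
  // MPI_Win_create call in the excerpt above.
  MPI_Win window;
  MPI_Win_create(buf.data(), buf.size() * sizeof(float), sizeof(float),
                 MPI_INFO_NULL, MPI_COMM_WORLD, &window);

  MPI_Win_fence(0, window);  // open the access epoch
  for (int peer = 0; peer < size; ++peer) {
    if (peer == rank) {
      continue;  // the local chunk already holds this rank's contribution
    }
    // Sum chunk `peer` of our buffer into the same offset on the peer.
    MPI_Accumulate(buf.data() + peer * data_num, data_num, MPI_FLOAT, peer,
                   peer * data_num, data_num, MPI_FLOAT, MPI_SUM, window);
  }
  MPI_Win_fence(0, window);  // completes all accumulates before reading

  // After the fence, chunk `rank` holds the reduced (summed) values.
  printf("rank %d reduced[0] = %f\n", rank, buf[rank * data_num]);
  MPI_Win_free(&window);
  MPI_Finalize();
  return 0;
}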
View File

@@ -34,9 +34,10 @@ class MPIAdapter {
   int GetRankId() const;
   bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kOpTypeSum);
-  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
-                                   const std::string &op_type = kOpTypeSum, float *output = nullptr);
-  bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
+  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
+                                   size_t output_size, const std::string &op_type = kOpTypeSum,
+                                   float *output = nullptr);
+  bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
  private:
   MPIAdapter();

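Callers of the old signature no longer compile, since output_size now precedes the defaulted op_type. A hypothetical call under the new signature (the include path, the two-rank group, and the chunk size are assumptions for illustration, not part of this commit):

#include <vector>
#include "device/cpu/mpi/mpi_adapter.h"  // assumed location of MPIAdapter

bool HostReduceScatterExample() {
  const size_t chunk_num = 16;              // floats per rank (assumed)
  std::vector<int> group = {0, 1};          // two-rank group (assumed)
  std::vector<float> input(group.size() * chunk_num, 1.0f);
  std::vector<float> output(chunk_num, 0.0f);
  // New in this commit: the output buffer's byte size travels with the call,
  // so the adapter can reject an undersized destination before memcpy_s runs.
  return mindspore::device::cpu::MPIAdapter::Instance().ReduceScatterOverwriteInput(
      input.data(), group, chunk_num, output.size() * sizeof(float),
      mindspore::device::cpu::kOpTypeSum, output.data());
}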
View File

@@ -26,21 +26,11 @@ constexpr auto kRanksGroup = "group";
 constexpr auto kAllGatherInputNum = 1;
 }  // namespace
-AllGatherCPUKernel::AllGatherCPUKernel() : input_data_number_(0) {}
 void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
   if (input_num != kAllGatherInputNum) {
     MS_LOG(EXCEPTION) << "allgather input num:" << input_num;
   }
-  for (size_t i = 0; i < input_num; ++i) {
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-    size_t count = 1;
-    for (size_t j = 0; j < shape.size(); j++) {
-      count *= IntToSize(shape[j]);
-    }
-    input_data_number_ += count;
-  }
   auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup);
   if (ranks_group != nullptr) {
@@ -55,8 +45,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                 const std::vector<kernel::AddressPtr> &outputs) {
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
-  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_number_);
+  auto input_data_num = inputs[0]->size / sizeof(float);
+  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num);
 }
 }  // namespace kernel
 }  // namespace mindspore

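The same launch-time pattern replaces the cached counters in both AllGatherCPUKernel here and ReduceScatterCPUKernel below: the element count is recovered from the address's byte size at Launch rather than accumulated from inferred shapes in InitKernel. A minimal sketch of the idea, assuming an Address-like struct with a byte size field (names are stand-ins, not the real kernel::Address):

#include <cassert>
#include <cstddef>

struct Address {  // stand-in for kernel::AddressPtr's pointee (shape assumed)
  void *addr;
  size_t size;    // payload size in bytes
};

// The float count falls out of the byte size; no shape product needed.
size_t FloatCount(const Address &buf) {
  assert(buf.size % sizeof(float) == 0);  // payload must be whole floats
  return buf.size / sizeof(float);
}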
View File

@@ -24,7 +24,7 @@ namespace mindspore {
 namespace kernel {
 class AllGatherCPUKernel : public CPUKernel {
  public:
-  AllGatherCPUKernel();
+  AllGatherCPUKernel() = default;
   ~AllGatherCPUKernel() override = default;
   void InitKernel(const CNodePtr &kernel_node) override;
@@ -33,7 +33,6 @@ class AllGatherCPUKernel : public CPUKernel {
                 const std::vector<AddressPtr> &outputs) override;
  private:
-  size_t input_data_number_;
   std::vector<int> ranks_group_;
 };

View File

@@ -24,18 +24,9 @@ namespace {
 constexpr auto kRanksGroup = "group";
 }  // namespace
-ReduceScatterCPUKernel::ReduceScatterCPUKernel() : output_data_number_(0), op_type_(device::cpu::kOpTypeSum) {}
+ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {}
 void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
-  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  for (size_t i = 0; i < output_num; ++i) {
-    auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
-    size_t size = 1;
-    for (size_t j = 0; j < shape.size(); j++) {
-      size *= IntToSize(shape[j]);
-    }
-    output_data_number_ += size;
-  }
   auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
   if (op != nullptr) {
     op_type_ = GetValue<std::string>(op);
@@ -54,8 +45,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
                                     const std::vector<kernel::AddressPtr> &outputs) {
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
-  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_number_,
+  auto output_data_num = outputs[0]->size / sizeof(float);
+  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
                                                            op_type_);
 }
 }  // namespace kernel

View File

@@ -33,7 +33,6 @@ class ReduceScatterCPUKernel : public CPUKernel {
                 const std::vector<AddressPtr> &outputs) override;
  private:
-  size_t output_data_number_;
   std::string op_type_;
   std::vector<int> ranks_group_;
 };