support host reduce
parent b106c2204a
commit 35900037af
@@ -179,8 +179,8 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec
   return result;
 }
 
-bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
-                                             const std::string &op_type, float *output) {
+bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
+                                             size_t output_size, const std::string &op_type, float *output) {
   int scatter_index = GetScatterIndex(rank_id_, ranks_group);
   auto group = AddGroup(ranks_group);
   if (group == MPI_GROUP_NULL) {
@@ -193,7 +193,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   }
 
   MPI_Win window;
-  auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
+  auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
   if (ret != MPI_SUCCESS) {
     MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
     return false;
@@ -205,18 +205,21 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
       continue;
     }
     auto op = GetMpiOp(op_type);
-    ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
-                         window);
+    ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num,
+                         input_data_num, MPI_FLOAT, op, window);
     if (ret != MPI_SUCCESS) {
       MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
     }
   }
   MPI_Win_fence(0, window);
   if (output != nullptr) {
-    auto data_size = data_num * sizeof(float);
-    auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
+    auto data_size = input_data_num * sizeof(float);
+    if (output_size < data_size) {
+      MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size;
+    }
+    auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size);
     if (copy_ret != 0) {
-      MS_LOG(EXCEPTION) << "copy output memory fail!";
+      MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret;
     }
   }
   MPI_Win_free(&window);
@@ -224,7 +227,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int
   return true;
 }
 
-bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
+bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
     return false;
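Note (not part of the commit): the hunks above keep the adapter's one-sided MPI pattern — expose the input buffer as an RMA window, MPI_Accumulate each slice into the owning rank, fence, then copy this rank's reduced slice out, now guarded by the new output_size check. Below is a minimal standalone sketch of that same pattern, assuming MPI_COMM_WORLD as the group and MPI_SUM as the reduction; names such as slot_num are illustrative and are not taken from MindSpore.

// Hedged sketch of the window/accumulate/fence flow; not the MindSpore adapter itself.
#include <mpi.h>
#include <cstring>
#include <iostream>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank = 0;
  int size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  const int slot_num = 4;  // elements each rank keeps after the reduce-scatter
  std::vector<float> input(static_cast<size_t>(slot_num) * size, static_cast<float>(rank + 1));
  std::vector<float> output(slot_num, 0.0f);

  // Expose the whole input buffer as an RMA window (analogue of MPI_Win_create above).
  MPI_Win window;
  MPI_Win_create(input.data(), static_cast<MPI_Aint>(input.size() * sizeof(float)), sizeof(float),
                 MPI_INFO_NULL, MPI_COMM_WORLD, &window);

  MPI_Win_fence(0, window);
  for (int remote = 0; remote < size; ++remote) {
    if (remote == rank) {
      continue;
    }
    // Add our copy of slice `remote` into that rank's window at the matching offset.
    MPI_Accumulate(input.data() + remote * slot_num, slot_num, MPI_FLOAT, remote,
                   static_cast<MPI_Aint>(remote) * slot_num, slot_num, MPI_FLOAT, MPI_SUM, window);
  }
  MPI_Win_fence(0, window);

  // Size-checked copy of this rank's reduced slice, like the new output_size guard.
  const size_t data_size = slot_num * sizeof(float);
  if (output.size() * sizeof(float) >= data_size) {
    std::memcpy(output.data(), input.data() + rank * slot_num, data_size);
  }

  MPI_Win_free(&window);
  std::cout << "rank " << rank << " reduced value: " << output[0] << std::endl;
  MPI_Finalize();
  return 0;
}

Built with mpic++ and run under mpirun, each rank should end up holding the element-wise sum of its own slice (size*(size+1)/2 with the values above).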
@@ -34,9 +34,10 @@ class MPIAdapter {
   int GetRankId() const;
   bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kOpTypeSum);
-  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
-                                   const std::string &op_type = kOpTypeSum, float *output = nullptr);
-  bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
+  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num,
+                                   size_t output_size, const std::string &op_type = kOpTypeSum,
+                                   float *output = nullptr);
+  bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
 
  private:
   MPIAdapter();
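Not part of the commit: a hedged sketch of how a call site adapts to the widened declaration, passing the output buffer's byte size so the adapter can reject a too-small destination before memcpy_s. The rank group and buffer sizes are placeholders, and input_data_num simply follows whatever convention the old data_num argument already had.

// Illustrative call-site fragment only; sizes and the rank group are made up.
std::vector<int> ranks_group = {0, 1};
std::vector<float> input(2048);
std::vector<float> output(1024);
size_t input_data_num = input.size();  // same meaning as the old data_num argument

bool ok = device::cpu::MPIAdapter::Instance().ReduceScatterOverwriteInput(
    input.data(), ranks_group, input_data_num,
    output.size() * sizeof(float),  // new output_size parameter
    device::cpu::kOpTypeSum, output.data());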
@@ -26,21 +26,11 @@ constexpr auto kRanksGroup = "group";
 constexpr auto kAllGatherInputNum = 1;
 }  // namespace
 
-AllGatherCPUKernel::AllGatherCPUKernel() : input_data_number_(0) {}
-
 void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
   if (input_num != kAllGatherInputNum) {
     MS_LOG(EXCEPTION) << "allgather input num:" << input_num;
   }
-  for (size_t i = 0; i < input_num; ++i) {
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-    size_t count = 1;
-    for (size_t j = 0; j < shape.size(); j++) {
-      count *= IntToSize(shape[j]);
-    }
-    input_data_number_ += count;
-  }
 
   auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup);
   if (ranks_group != nullptr) {
@@ -55,8 +45,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                 const std::vector<kernel::AddressPtr> &outputs) {
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  auto input_data_num = inputs[0]->size / sizeof(float);
 
-  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_number_);
+  return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num);
 }
 }  // namespace kernel
 }  // namespace mindspore
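Note (not from the commit): the kernel no longer caches an element count computed from inferred shapes at init time; Launch derives it from the buffer's byte size instead. A tiny hedged sketch of that derivation follows; the Address struct is a stand-in for the kernel address type, not the real MindSpore definition.

// Hedged sketch: element count from the buffer's byte size, as the Launch body now does.
#include <cassert>
#include <cstddef>

struct Address {  // stand-in for the kernel's address/size pair
  void *addr;
  size_t size;    // size in bytes
};

size_t FloatElementCount(const Address &buf) {
  assert(buf.size % sizeof(float) == 0);  // buffer must hold whole floats
  return buf.size / sizeof(float);        // e.g. a 4096-byte buffer yields 1024 elements
}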
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace kernel {
 class AllGatherCPUKernel : public CPUKernel {
  public:
-  AllGatherCPUKernel();
+  AllGatherCPUKernel() = default;
   ~AllGatherCPUKernel() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
@@ -33,7 +33,6 @@ class AllGatherCPUKernel : public CPUKernel {
                 const std::vector<AddressPtr> &outputs) override;
 
  private:
-  size_t input_data_number_;
   std::vector<int> ranks_group_;
 };
 
@@ -24,18 +24,9 @@ namespace {
 constexpr auto kRanksGroup = "group";
 }  // namespace
 
-ReduceScatterCPUKernel::ReduceScatterCPUKernel() : output_data_number_(0), op_type_(device::cpu::kOpTypeSum) {}
+ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {}
 
 void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
-  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  for (size_t i = 0; i < output_num; ++i) {
-    auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
-    size_t size = 1;
-    for (size_t j = 0; j < shape.size(); j++) {
-      size *= IntToSize(shape[j]);
-    }
-    output_data_number_ += size;
-  }
   auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
   if (op != nullptr) {
     op_type_ = GetValue<std::string>(op);
@@ -54,8 +45,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
                                     const std::vector<kernel::AddressPtr> &outputs) {
   auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
   auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  auto output_data_num = outputs[0]->size / sizeof(float);
 
-  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_number_,
+  return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num,
                                                             op_type_);
 }
 }  // namespace kernel
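Note (not from the commit): as with AllGather above, the ReduceScatter kernel now reads its per-rank output element count from outputs[0]->size at launch. For orientation, a small self-contained, MPI-free illustration of the reduce-scatter contract that count feeds into — rank k keeps slice k of the element-wise reduction — with a made-up group size and values.

// Hedged illustration of reduce-scatter semantics on plain vectors; not MindSpore code.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t group_size = 2;
  const size_t output_data_num = 3;  // elements each rank keeps (outputs[0]->size / sizeof(float))
  const std::vector<std::vector<float>> inputs = {
      {1, 1, 1, 1, 1, 1},   // rank 0 contributes group_size * output_data_num elements
      {2, 2, 2, 2, 2, 2}};  // rank 1 contributes the same amount

  for (size_t k = 0; k < group_size; ++k) {
    std::vector<float> output(output_data_num, 0.0f);
    for (size_t r = 0; r < group_size; ++r) {
      for (size_t j = 0; j < output_data_num; ++j) {
        output[j] += inputs[r][k * output_data_num + j];  // sum slice k across ranks
      }
    }
    std::cout << "rank " << k << " keeps " << output[0] << ", ...\n";  // 3 for both ranks here
  }
  return 0;
}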
@@ -33,7 +33,6 @@ class ReduceScatterCPUKernel : public CPUKernel {
                 const std::vector<AddressPtr> &outputs) override;
 
  private:
-  size_t output_data_number_;
   std::string op_type_;
   std::vector<int> ranks_group_;
 };