!41495 fix cpu/gpu communication tensor

Merge pull request !41495 from 王禹程/fix_cpu
commit 28bc5f3328
Authored by i-robot, committed by Gitee, 2022-09-06 01:56:06 +00:00
8 changed files with 48 additions and 15 deletions

mindspore/ccsrc/backend/common/somas/somas.cc

@@ -172,6 +172,8 @@ bool Somas::Assign(const KernelGraphPtr &graph_ptr) {
size_t Somas::GetCommunicationReservedSize() const { return 0; }
void Somas::CommunicationTensorProcess(const std::vector<SomasTensorPtr> &tensors) const {}
bool Somas::GetEnableCacheFlag(const session::KernelGraph &graph) const {
return graph.execution_order().size() >= kCachedResultThreshold;
}
@@ -1118,14 +1120,7 @@ void Somas::CommunicationNodeProcess() {
// Contiguous input
if ((!node->input_tensors_.empty()) && (!node->input_tensors_[0]->contiguous_)) {
// add gap for first and last input
if (node->input_tensors_[0]->aligned_size_ != 0) {
node->input_tensors_[0]->aligned_size_ += communication_gap_size_;
}
if (node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ != 0) {
node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += communication_gap_size_;
}
CommunicationTensorProcess(node->input_tensors_);
std::vector<size_t> inputs;
for (const auto &input_tensor : node->input_tensors_) {
MS_EXCEPTION_IF_NULL(input_tensor);
@@ -1142,13 +1137,7 @@ void Somas::CommunicationNodeProcess() {
// Contiguous output
if ((!node->output_tensors_.empty()) && (!node->output_tensors_[0]->contiguous_)) {
// add gap for first and last output
if (node->output_tensors_[0]->aligned_size_ != 0) {
node->output_tensors_[0]->aligned_size_ += communication_gap_size_;
}
if (node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ != 0) {
node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += communication_gap_size_;
}
CommunicationTensorProcess(node->output_tensors_);
std::vector<size_t> outputs;
for (const auto &output_tensor : node->output_tensors_) {
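
Taken together, the two hunks above replace the hand-written first/last gap adjustment with a call to a new virtual hook, CommunicationTensorProcess, whose base implementation is empty. A minimal, self-contained sketch of that shape (the Tensor and SomasSketch types below are simplified stand-ins, not the real MindSpore classes):

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-in for SomasTensorPtr; the real tensors live in
// backend/common/somas/somas.h and carry far more state.
struct Tensor {
  explicit Tensor(size_t size) : original_size_(size), aligned_size_(size) {}
  size_t original_size_;
  size_t aligned_size_;
};
using TensorPtr = std::shared_ptr<Tensor>;

class SomasSketch {
 public:
  virtual ~SomasSketch() = default;

  // New hook: the base implementation is intentionally empty, matching the
  // default added to somas.cc; device backends override it.
  virtual void CommunicationTensorProcess(const std::vector<TensorPtr> &tensors) const {}

  // The contiguous-input / contiguous-output branches of CommunicationNodeProcess
  // now delegate to the hook instead of touching the first/last tensors inline.
  void ProcessContiguousList(const std::vector<TensorPtr> &tensors) const {
    if (!tensors.empty()) {
      CommunicationTensorProcess(tensors);
    }
  }
};

int main() {
  std::vector<TensorPtr> inputs = {std::make_shared<Tensor>(1024), std::make_shared<Tensor>(2048)};
  SomasSketch base;
  base.ProcessContiguousList(inputs);  // default hook: sizes are left untouched
  for (const auto &t : inputs) {
    std::cout << t->aligned_size_ << '\n';  // prints 1024 then 2048
  }
  return 0;
}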

mindspore/ccsrc/backend/common/somas/somas.h

@@ -104,6 +104,7 @@ class Somas {
virtual bool InitDevSpecControlTensors(const session::KernelGraph &graph) = 0;
virtual bool DevSpecNodeProcess(const session::KernelGraph &graph) = 0;
virtual void CommunicationTensorProcess(const std::vector<SomasTensorPtr> &tensors) const;
// end
// SOMAS Configuration

mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_somas.cc

@@ -217,6 +217,16 @@ void AscendSomas::NonTaskSplitProcess(const session::KernelGraph &graph) {
}
}
}
void AscendSomas::CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const {
// add gap for the first and last communication tensor (used for both inputs and outputs)
if (tensors[0]->aligned_size_ != 0) {
tensors[0]->aligned_size_ += GetCommunicationReservedSize();
}
if (tensors[tensors.size() - 1]->aligned_size_ != 0) {
tensors[tensors.size() - 1]->aligned_size_ += GetCommunicationReservedSize();
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore
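
For context, the Ascend override above keeps the original behaviour: the communication reserved size is added once to the first and once to the last non-zero tensor of the contiguous list. A hedged sketch of that arithmetic, using an illustrative 512-byte gap in place of whatever AscendSomas::GetCommunicationReservedSize() actually returns:

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: the real gap comes from AscendSomas::GetCommunicationReservedSize().
constexpr size_t kGapSize = 512;

// Mirrors the Ascend override: grow the first and last non-zero sizes by one gap each.
void AddEndGaps(std::vector<size_t> *aligned_sizes) {
  if (aligned_sizes->empty()) {
    return;
  }
  if (aligned_sizes->front() != 0) {
    aligned_sizes->front() += kGapSize;
  }
  if (aligned_sizes->back() != 0) {
    aligned_sizes->back() += kGapSize;
  }
}

int main() {
  std::vector<size_t> sizes = {1024, 2048};
  AddEndGaps(&sizes);
  // Prints 1536 and 2560: the contiguous block grows by exactly two gaps.
  std::cout << sizes[0] << ' ' << sizes[1] << '\n';
  return 0;
}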

mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_somas.h

@@ -39,6 +39,7 @@ class AscendSomas : public somas::Somas {
bool Initialize() override;
string GetDeviceName() const override;
size_t GetCommunicationReservedSize() const override;
void CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const override;
size_t GetAlignSize(size_t original_size) const override;
bool GetDependExecOrderFlag(const session::KernelGraph &graph) const override;

mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_somas.cc

@@ -36,6 +36,20 @@ bool CPUSomas::GetDependExecOrderFlag(const session::KernelGraph &graph) const {
bool CPUSomas::InitDevSpecControlTensors(const session::KernelGraph &graph) { return true; }
bool CPUSomas::DevSpecNodeProcess(const session::KernelGraph &graph) { return true; }
void CPUSomas::CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const {
size_t all_communication_size = 0;
for (auto &tensor : tensors) {
tensor->aligned_size_ = tensor->GetOriginalSize();
if (tensor->aligned_size_ == 0) {
MS_LOG(ERROR) << "The size of the communication tensor is zero.";
}
all_communication_size += tensor->aligned_size_;
}
auto aligned_communication_size = GetAlignSize(all_communication_size);
auto need_aligned = aligned_communication_size - all_communication_size;
tensors[tensors.size() - 1]->aligned_size_ += need_aligned;
}
} // namespace cpu
} // namespace device
} // namespace mindspore
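
The CPU override above (and the GPU override further below) takes a different route: it resets every tensor to its original size, sums the list, rounds the total up to the device alignment, and puts all of the padding on the last tensor so the contiguous block ends on an aligned boundary. A hedged sketch of that arithmetic, assuming a 512-byte granularity in place of the real CPUSomas::GetAlignSize():

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Assumed 512-byte granularity for illustration; the real rounding is GetAlignSize().
constexpr size_t kAlign = 512;
size_t AlignUp(size_t size) { return (size + kAlign - 1) / kAlign * kAlign; }

// Mirrors the CPU/GPU hook: only the last tensor absorbs the alignment padding.
void PadLastTensor(std::vector<size_t> *sizes) {
  if (sizes->empty()) {
    return;
  }
  size_t total = 0;
  for (size_t s : *sizes) {
    assert(s != 0 && "communication tensors are expected to be non-empty");
    total += s;
  }
  sizes->back() += AlignUp(total) - total;
}

int main() {
  std::vector<size_t> sizes = {1000, 2000, 3000};  // sums to 6000 bytes
  PadLastTensor(&sizes);
  // AlignUp(6000) == 6144, so the last tensor grows from 3000 to 3144.
  std::cout << sizes.back() << '\n';
  return 0;
}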

mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_somas.h

@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_plugin_DEVICE_CPU_HAL_HARDWARE_CPU_SOMAS_H__
#include <string>
#include <vector>
#include "backend/common/somas/somas.h"
#include "runtime/hardware/device_type.h"
@@ -30,6 +31,7 @@ class CPUSomas : public somas::Somas {
bool Initialize() override;
string GetDeviceName() const override;
size_t GetAlignSize(size_t original_size) const override;
void CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const override;
bool GetDependExecOrderFlag(const session::KernelGraph &graph) const override;
bool InitDevSpecControlTensors(const session::KernelGraph &graph) override;

mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_somas.cc

@@ -136,6 +136,20 @@ bool GPUSomas::InplaceNodeProcess(const session::KernelGraph &graph) {
}
return true;
}
void GPUSomas::CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const {
size_t all_communication_size = 0;
for (auto &tensor : tensors) {
tensor->aligned_size_ = tensor->GetOriginalSize();
if (tensor->aligned_size_ == 0) {
MS_LOG(ERROR) << "The size of the communication tensor is zero.";
}
all_communication_size += tensor->aligned_size_;
}
auto aligned_communication_size = GetAlignSize(all_communication_size);
auto need_aligned = aligned_communication_size - all_communication_size;
tensors[tensors.size() - 1]->aligned_size_ += need_aligned;
}
} // namespace gpu
} // namespace device
} // namespace mindspore
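
The GPU override above is the same logic as the CPU one. The invariant both establish can be stated as: after the hook runs, the summed aligned sizes equal the aligned sum of the original sizes, with only the last tensor absorbing padding. A small hedged check, reusing the 512-byte alignment assumption from the CPU sketch:

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

constexpr size_t kAlign = 512;  // illustrative granularity, as in the CPU sketch
size_t AlignUp(size_t size) { return (size + kAlign - 1) / kAlign * kAlign; }

int main() {
  std::vector<size_t> original = {1000, 2000, 3000};
  std::vector<size_t> aligned = original;
  size_t total = std::accumulate(original.begin(), original.end(), size_t{0});
  aligned.back() += AlignUp(total) - total;  // what the hook does to the last tensor

  size_t aligned_total = std::accumulate(aligned.begin(), aligned.end(), size_t{0});
  // Invariant: the contiguous block as a whole ends on an aligned boundary.
  assert(aligned_total == AlignUp(total));
  return 0;
}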

mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_somas.h

@@ -19,6 +19,7 @@
#include <map>
#include <string>
#include <vector>
#include "backend/common/somas/somas.h"
#include "runtime/hardware/device_type.h"
@@ -32,6 +33,7 @@ class GPUSomas : public somas::Somas {
bool Initialize() override;
string GetDeviceName() const override;
size_t GetAlignSize(size_t original_size) const override;
void CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const override;
bool GetDependExecOrderFlag(const session::KernelGraph &graph) const override;
bool InitDevSpecControlTensors(const session::KernelGraph &graph) override;