Merge pull request !38219 from chengbin/fix_hccl_master
i-robot 2022-07-18 11:59:55 +00:00 committed by Gitee
commit 851f52ce86
2 changed files with 39 additions and 10 deletions
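The diff extracts the loop-size calculation out of HcclKernel::GetOutputSizeList() into a new CalLoopSize() helper, which is invoked once from Init() and caches its result in a new loop_size_ member. It also adds an UpdateOutputSizeList() helper that Resize() calls so the cached output sizes can be rebuilt when shapes change.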


@@ -143,6 +143,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
common::AnfAlgo::SetNodeAttr(kAttrComm, MakeValue<int64_t>((int64_t)comm), anf_node);
}
anf_node_ = anf_node;
CalLoopSize();
return true;
}
@@ -170,15 +171,11 @@ const std::vector<size_t> &HcclKernel::GetInputSizeList() const {
return mutable_input_size_list_;
}
const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
void HcclKernel::CalLoopSize() {
auto anf_node = anf_node_.lock();
if (!anf_node) {
MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
}
size_t size = 0;
if (!mutable_output_size_list_.empty()) {
return mutable_output_size_list_;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = common::AnfAlgo::GetCNodeName(cnode);
@@ -186,6 +183,7 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
}
int64_t fusion = 0;
if (common::AnfAlgo::HasNodeAttr(kAttrFusion, cnode)) {
fusion = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
@@ -194,14 +192,22 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
MS_LOG(EXCEPTION) << "Invalid data type size " << hccl_data_type_list_.size() << " diff shape size "
<< hccl_kernel_input_shape_list_.size();
}
ulong loop_size = hccl_data_type_list_.size();
loop_size_ = hccl_data_type_list_.size();
if (common::AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
loop_size *= static_cast<ulong>(rank_size);
loop_size_ *= static_cast<ulong>(rank_size);
}
if (op_name == kReduceScatterOpName && fusion >= 1) {
loop_size = common::AnfAlgo::GetOutputTensorNum(anf_node);
loop_size_ = common::AnfAlgo::GetOutputTensorNum(anf_node);
}
for (ulong i = 0; i < loop_size; ++i) {
}
const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
size_t size = 0;
if (!mutable_output_size_list_.empty()) {
return mutable_output_size_list_;
}
for (ulong i = 0; i < loop_size_; ++i) {
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
}
@@ -344,6 +350,26 @@ bool HcclKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector
return true;
}
void HcclKernel::UpdateOutputSizeList() {
auto anf_node = anf_node_.lock();
if (!anf_node) {
MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
}
size_t size = 0;
hccl_kernel_output_shape_list_.clear();
mutable_output_size_list_.clear();
if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) {
MS_LOG(EXCEPTION) << "GetKernelOutputShape fail!";
}
for (ulong i = 0; i < loop_size_; ++i) {
if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
MS_LOG(EXCEPTION) << "GetHcclOpOutputSize failed";
}
mutable_output_size_list_.push_back(size);
}
}
int HcclKernel::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
@@ -358,7 +384,7 @@ int HcclKernel::Resize(const BaseOperatorPtr &base_operator, const std::vector<K
}
MS_LOG(INFO) << "Start to InitOp. Node info: " << cnode->DebugString();
UpdateOutputSizeList();
std::vector<ShapeVector> hccl_kernel_input_shape_list;
if (!HcomUtil::GetKernelInputShape(cnode, &hccl_kernel_input_shape_list)) {
MS_LOG(EXCEPTION) << "GetKernelInputShape fail! Node info: " << cnode->DebugString();


@@ -56,6 +56,8 @@ class HcclKernel : public AscendKernelMod {
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
protected:
void UpdateOutputSizeList();
void CalLoopSize();
std::vector<std::vector<int64_t>> hccl_kernel_input_shape_list_;
std::vector<std::vector<int64_t>> hccl_kernel_output_shape_list_;
std::vector<HcclDataType> hccl_data_type_list_;
@@ -72,6 +74,7 @@ class HcclKernel : public AscendKernelMod {
std::string group_;
std::mutex hccl_mutex_;
std::condition_variable cond_;
ulong loop_size_{0};
};
using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
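In the header, both new methods stay protected and loop_size_{0} is a plain cached count, so the public AscendKernelMod interface is unchanged: Resize() can refresh mutable_output_size_list_ through UpdateOutputSizeList() without re-reading the kAttrRankSize and kAttrFusion attributes, since that work now happens once in CalLoopSize() during Init().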