!38219 fix hccl
Merge pull request !38219 from chengbin/fix_hccl_master
Commit: 851f52ce86
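Summary (editor's reading of the diff, not text from the PR): the old const GetOutputSizeList() both computed the HCCL loop count and filled the output-size list. The diff splits that work: a new CalLoopSize(), called at the end of Init(), derives the loop count from the node's rank-size/fusion attributes and caches it in a new loop_size_ member; the const getter now only fills mutable_output_size_list_ from cached state; and a new UpdateOutputSizeList(), called from Resize(), rebuilds the output shape and size lists so dynamic-shape resizes pick up fresh sizes. The sketch below illustrates that control flow only; HcclKernelSketch, Shape, and BytesOf are simplified stand-in names, not the MindSpore API.

// Minimal, self-contained sketch of the refactored control flow
// (assumption: simplified stand-in types, not the real kernel classes).
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using Shape = std::vector<size_t>;

class HcclKernelSketch {
 public:
  // Mirrors Init() + CalLoopSize(): the loop count is derived once from
  // static graph attributes and cached in loop_size_.
  void Init(size_t output_num, size_t rank_size, bool fused_all_gather) {
    loop_size_ = fused_all_gather ? output_num * rank_size : output_num;
  }

  // Mirrors UpdateOutputSizeList(), called from Resize(): output shapes may
  // change under dynamic shape, so the byte sizes are rebuilt from scratch.
  void Resize(const std::vector<Shape> &output_shapes) {
    output_size_list_.clear();
    for (size_t i = 0; i < loop_size_ && i < output_shapes.size(); ++i) {
      output_size_list_.push_back(BytesOf(output_shapes[i]));
    }
  }

  // Mirrors the slimmed-down const GetOutputSizeList(): it only reads cached
  // state and no longer recomputes the loop count.
  const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }

 private:
  // Element count times sizeof(float), as a placeholder for the real size helper.
  static size_t BytesOf(const Shape &s) {
    return std::accumulate(s.begin(), s.end(), size_t{1}, std::multiplies<size_t>()) * sizeof(float);
  }

  std::vector<size_t> output_size_list_;
  size_t loop_size_{0};
};

int main() {
  HcclKernelSketch k;
  k.Init(/*output_num=*/1, /*rank_size=*/8, /*fused_all_gather=*/false);
  k.Resize({{16, 8}});                                  // dynamic shape arrives at Resize time
  std::cout << k.GetOutputSizeList()[0] << " bytes\n";  // 16 * 8 * sizeof(float) = 512
}

The practical effect, as far as the diff shows, is that the const getter no longer carries loop bookkeeping, and Resize() gets a dedicated hook (UpdateOutputSizeList) to refresh output sizes.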
@@ -143,6 +143,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
     common::AnfAlgo::SetNodeAttr(kAttrComm, MakeValue<int64_t>((int64_t)comm), anf_node);
   }
   anf_node_ = anf_node;
+  CalLoopSize();
   return true;
 }

@@ -170,15 +171,11 @@ const std::vector<size_t> &HcclKernel::GetInputSizeList() const {
   return mutable_input_size_list_;
 }

-const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
+void HcclKernel::CalLoopSize() {
   auto anf_node = anf_node_.lock();
   if (!anf_node) {
     MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
   }
-  size_t size = 0;
-  if (!mutable_output_size_list_.empty()) {
-    return mutable_output_size_list_;
-  }
   auto cnode = anf_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   auto op_name = common::AnfAlgo::GetCNodeName(cnode);
@@ -186,6 +183,7 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
   if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
     rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
   }
+
   int64_t fusion = 0;
   if (common::AnfAlgo::HasNodeAttr(kAttrFusion, cnode)) {
     fusion = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
@@ -194,14 +192,22 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
     MS_LOG(EXCEPTION) << "Invalid data type size " << hccl_data_type_list_.size() << " diff shape size "
                       << hccl_kernel_input_shape_list_.size();
   }
-  ulong loop_size = hccl_data_type_list_.size();
+  loop_size_ = hccl_data_type_list_.size();
   if (common::AnfAlgo::GetInputTensorNum(anf_node) > 1 && op_name == kAllGatherOpName && fusion >= 1) {
-    loop_size *= static_cast<ulong>(rank_size);
+    loop_size_ *= static_cast<ulong>(rank_size);
   }
   if (op_name == kReduceScatterOpName && fusion >= 1) {
-    loop_size = common::AnfAlgo::GetOutputTensorNum(anf_node);
+    loop_size_ = common::AnfAlgo::GetOutputTensorNum(anf_node);
   }
-  for (ulong i = 0; i < loop_size; ++i) {
+}
+
+const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
+  size_t size = 0;
+  if (!mutable_output_size_list_.empty()) {
+    return mutable_output_size_list_;
+  }
+
+  for (ulong i = 0; i < loop_size_; ++i) {
     if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
       MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
     }
@@ -344,6 +350,26 @@ bool HcclKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector
   return true;
 }

+void HcclKernel::UpdateOutputSizeList() {
+  auto anf_node = anf_node_.lock();
+  if (!anf_node) {
+    MS_LOG(EXCEPTION) << "anf_node pointer is expired.";
+  }
+  size_t size = 0;
+  hccl_kernel_output_shape_list_.clear();
+  mutable_output_size_list_.clear();
+  if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) {
+    MS_LOG(EXCEPTION) << "GetKernelOutputShape fail!";
+  }
+
+  for (ulong i = 0; i < loop_size_; ++i) {
+    if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[0], hccl_kernel_output_shape_list_[i], &size)) {
+      MS_LOG(EXCEPTION) << "GetHcclOpOutputSize failed";
+    }
+    mutable_output_size_list_.push_back(size);
+  }
+}
+
 int HcclKernel::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                        const std::vector<KernelTensorPtr> &outputs,
                        const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
@@ -358,7 +384,7 @@ int HcclKernel::Resize(const BaseOperatorPtr &base_operator, const std::vector<K
   }

   MS_LOG(INFO) << "Start to InitOp. Node info: " << cnode->DebugString();
-
+  UpdateOutputSizeList();
   std::vector<ShapeVector> hccl_kernel_input_shape_list;
   if (!HcomUtil::GetKernelInputShape(cnode, &hccl_kernel_input_shape_list)) {
     MS_LOG(EXCEPTION) << "GetKernelInputShape fail! Node info: " << cnode->DebugString();
(The remaining hunks modify the HcclKernel class declaration.)
@@ -56,6 +56,8 @@ class HcclKernel : public AscendKernelMod {
             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  protected:
+  void UpdateOutputSizeList();
+  void CalLoopSize();
   std::vector<std::vector<int64_t>> hccl_kernel_input_shape_list_;
   std::vector<std::vector<int64_t>> hccl_kernel_output_shape_list_;
   std::vector<HcclDataType> hccl_data_type_list_;
@@ -72,6 +74,7 @@ class HcclKernel : public AscendKernelMod {
   std::string group_;
   std::mutex hccl_mutex_;
   std::condition_variable cond_;
+  ulong loop_size_{0};
 };

 using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;