!39909 Support tbe atomic clean for akg fusion op

Merge pull request !39909 from NaCN/tbe_atomic_clean_for_akg
This commit is contained in:
i-robot 2022-08-11 09:22:25 +00:00 committed by Gitee
commit 1e215065fa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 11 additions and 2 deletions

View File

@ -428,7 +428,9 @@ void TagNeedInsertAtomicAttr(const std::vector<CNodePtr> &nodes) {
}
common::AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
common::AnfAlgo::SetNodeAttr(kAttrNeedAtomic, MakeValue(true), anf_node);
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node)) {
} else if ((AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL ||
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) &&
IsAtomicNode(anf_node)) {
common::AnfAlgo::SetNodeAttr(kAttrNeedAtomic, MakeValue(true), anf_node);
}
}

View File

@ -130,5 +130,10 @@ std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &in
input_data_addrs, output_data_addrs, workspace_addrs, NeedDump());
return {task_info_ptr};
}
std::vector<size_t> AkgKernelMod::GenParameters() {
auto kernel_json_info = kernel_pack_->kernel_json_info();
return kernel_json_info.parameters;
}
} // namespace kernel
} // namespace mindspore

View File

@ -32,6 +32,7 @@ class AkgKernelMod : public AscendKernelMod {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
std::vector<size_t> GenParameters() override;
private:
KernelPackPtr kernel_pack_;

View File

@ -464,7 +464,8 @@ bool AnfAlgo::HasNodeAttr(const std::string &key, const CNodePtr &node) {
return false;
}
// call node's input0 is not a primitive.
if (!IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
if (!IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0)) &&
!IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
return false;
}
// single op cnode.