forked from mindspore-Ecosystem/mindspore
!39909 Support tbe atomic clean for akg fusion op
Merge pull request !39909 from NaCN/tbe_atomic_clean_for_akg
This commit is contained in:
commit
1e215065fa
|
@ -428,7 +428,9 @@ void TagNeedInsertAtomicAttr(const std::vector<CNodePtr> &nodes) {
|
|||
}
|
||||
common::AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrNeedAtomic, MakeValue(true), anf_node);
|
||||
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node)) {
|
||||
} else if ((AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL ||
|
||||
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) &&
|
||||
IsAtomicNode(anf_node)) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrNeedAtomic, MakeValue(true), anf_node);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -130,5 +130,10 @@ std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &in
|
|||
input_data_addrs, output_data_addrs, workspace_addrs, NeedDump());
|
||||
return {task_info_ptr};
|
||||
}
|
||||
|
||||
std::vector<size_t> AkgKernelMod::GenParameters() {
|
||||
auto kernel_json_info = kernel_pack_->kernel_json_info();
|
||||
return kernel_json_info.parameters;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,7 @@ class AkgKernelMod : public AscendKernelMod {
|
|||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
|
||||
std::vector<size_t> GenParameters() override;
|
||||
|
||||
private:
|
||||
KernelPackPtr kernel_pack_;
|
||||
|
|
|
@ -464,7 +464,8 @@ bool AnfAlgo::HasNodeAttr(const std::string &key, const CNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
// call node's input0 is not a primitive.
|
||||
if (!IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
|
||||
if (!IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0)) &&
|
||||
!IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) {
|
||||
return false;
|
||||
}
|
||||
// single op cnode.
|
||||
|
|
Loading…
Reference in New Issue