From 39fc750e29f0642f61e71523c53b77b827877c9c Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Thu, 17 Nov 2022 19:49:57 +0800 Subject: [PATCH] call tbe checksupported before launch transdata at pynative mode --- .../hal/device/ascend_launch_transdata.cc | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc index f01cafd6419..846bfa8cfff 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_launch_transdata.cc @@ -21,8 +21,25 @@ #include "backend/common/session/single_kernel_graph.h" #include "backend/common/session/anf_runtime_algorithm.h" #include "include/common/utils/anfalgo.h" +#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h" +#include "plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.h" namespace mindspore::device::ascend { +namespace { +bool TbeCheckSupported(const CNodePtr &transdata_node) { + MS_EXCEPTION_IF_NULL(transdata_node); + auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance(); + auto json_creator = std::make_shared(); + MS_EXCEPTION_IF_NULL(json_creator); + nlohmann::json kernel_json; + auto ret = json_creator->GenJson(transdata_node, &kernel_json); + if (!ret) { + MS_LOG(EXCEPTION) << "Gen node hash failed. [" << transdata_node->fullname_with_scope() << "]"; + } + ret = build_manager.TbeOpCheckSupported(transdata_node, &kernel_json); + return ret; +} +} // namespace void AscendLaunchTransData::FreeDeviceMem(void *addr) { AscendLaunchKernel::FreeDeviceMem(addr); } size_t AscendLaunchTransData::AlignSizeForLaunchKernel(size_t size) { @@ -114,6 +131,9 @@ void AscendLaunchTransData::ConstructKernelGraphAndSetAttr() { common::AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(dst_format_), transdata_node); common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups_), transdata_node); common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups_), transdata_node); + if (!TbeCheckSupported(transdata_node)) { + builder->SetKernelType(KernelType::AICPU_KERNEL); + } } } } // namespace mindspore::device::ascend