call tbe checksupported before launch transdata at pynative mode

This commit is contained in:
zhoufeng 2022-11-17 19:49:57 +08:00
parent 2a56d387cf
commit 39fc750e29
1 changed files with 20 additions and 0 deletions

View File

@ -21,8 +21,25 @@
#include "backend/common/session/single_kernel_graph.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
#include "plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.h"
namespace mindspore::device::ascend {
namespace {
bool TbeCheckSupported(const CNodePtr &transdata_node) {
MS_EXCEPTION_IF_NULL(transdata_node);
auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
auto json_creator = std::make_shared<kernel::SelectTbeJsonCreator>();
MS_EXCEPTION_IF_NULL(json_creator);
nlohmann::json kernel_json;
auto ret = json_creator->GenJson(transdata_node, &kernel_json);
if (!ret) {
MS_LOG(EXCEPTION) << "Gen node hash failed. [" << transdata_node->fullname_with_scope() << "]";
}
ret = build_manager.TbeOpCheckSupported(transdata_node, &kernel_json);
return ret;
}
} // namespace
void AscendLaunchTransData::FreeDeviceMem(void *addr) { AscendLaunchKernel::FreeDeviceMem(addr); }
size_t AscendLaunchTransData::AlignSizeForLaunchKernel(size_t size) {
@ -114,6 +131,9 @@ void AscendLaunchTransData::ConstructKernelGraphAndSetAttr() {
common::AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(dst_format_), transdata_node);
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups_), transdata_node);
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups_), transdata_node);
if (!TbeCheckSupported(transdata_node)) {
builder->SetKernelType(KernelType::AICPU_KERNEL);
}
}
}
} // namespace mindspore::device::ascend