call tbe checksupported before launch transdata at pynative mode
This commit is contained in:
parent
2a56d387cf
commit
39fc750e29
|
@ -21,8 +21,25 @@
|
||||||
#include "backend/common/session/single_kernel_graph.h"
|
#include "backend/common/session/single_kernel_graph.h"
|
||||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||||
#include "include/common/utils/anfalgo.h"
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
|
||||||
|
#include "plugin/device/ascend/kernel/tbe/tbe_json/single_tbe_json_creator.h"
|
||||||
|
|
||||||
namespace mindspore::device::ascend {
|
namespace mindspore::device::ascend {
|
||||||
|
namespace {
|
||||||
|
bool TbeCheckSupported(const CNodePtr &transdata_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(transdata_node);
|
||||||
|
auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
|
||||||
|
auto json_creator = std::make_shared<kernel::SelectTbeJsonCreator>();
|
||||||
|
MS_EXCEPTION_IF_NULL(json_creator);
|
||||||
|
nlohmann::json kernel_json;
|
||||||
|
auto ret = json_creator->GenJson(transdata_node, &kernel_json);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "Gen node hash failed. [" << transdata_node->fullname_with_scope() << "]";
|
||||||
|
}
|
||||||
|
ret = build_manager.TbeOpCheckSupported(transdata_node, &kernel_json);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
void AscendLaunchTransData::FreeDeviceMem(void *addr) { AscendLaunchKernel::FreeDeviceMem(addr); }
|
void AscendLaunchTransData::FreeDeviceMem(void *addr) { AscendLaunchKernel::FreeDeviceMem(addr); }
|
||||||
|
|
||||||
size_t AscendLaunchTransData::AlignSizeForLaunchKernel(size_t size) {
|
size_t AscendLaunchTransData::AlignSizeForLaunchKernel(size_t size) {
|
||||||
|
@ -114,6 +131,9 @@ void AscendLaunchTransData::ConstructKernelGraphAndSetAttr() {
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(dst_format_), transdata_node);
|
common::AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(dst_format_), transdata_node);
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups_), transdata_node);
|
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups_), transdata_node);
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups_), transdata_node);
|
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(groups_), transdata_node);
|
||||||
|
if (!TbeCheckSupported(transdata_node)) {
|
||||||
|
builder->SetKernelType(KernelType::AICPU_KERNEL);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace mindspore::device::ascend
|
} // namespace mindspore::device::ascend
|
||||||
|
|
Loading…
Reference in New Issue