Update TopKSplit pass

This commit is contained in:
yuchaojie 2022-07-12 10:08:44 +08:00
parent c218d871be
commit 8029bbf815
8 changed files with 211 additions and 102 deletions

View File

@ -18,6 +18,7 @@
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/dev.h" #include "runtime/dev.h"
#include "opt_info/opt_info.h" #include "opt_info/opt_info.h"
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
namespace mindspore { namespace mindspore {
namespace { namespace {
@ -55,18 +56,6 @@ inline std::vector<std::string> SplitStrByRegex(const std::string &str, const st
return std::vector<std::string>(std::sregex_token_iterator(str.begin(), str.end(), split, -1), return std::vector<std::string>(std::sregex_token_iterator(str.begin(), str.end(), split, -1),
std::sregex_token_iterator()); std::sregex_token_iterator());
} }
static std::string GetSocVersion() {
constexpr int kSocVersionLen = 50;
char soc_version[kSocVersionLen] = {0};
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
if (ret != RT_ERROR_NONE) {
MS_LOG(WARNING) << "rtGetSocVersion failed, ret = " << ret;
return "Ascend910";
}
return soc_version;
}
} // namespace } // namespace
LicManager::LicManager() { ParseSwitch(); } LicManager::LicManager() { ParseSwitch(); }
@ -87,7 +76,7 @@ bool LicManager::GetPassSwitch(OptPassEnum pass) const {
void LicManager::ParseSwitch() { void LicManager::ParseSwitch() {
std::map<std::string, std::string> opt_info_map; std::map<std::string, std::string> opt_info_map;
auto ret = gelc::GetOptInfo(0, GetSocVersion(), opt_info_map); auto ret = gelc::GetOptInfo(0, device::ascend::GetSocVersion(), opt_info_map);
if (ret != 0) { if (ret != 0) {
MS_LOG(WARNING) << "GetOptInfo failed."; MS_LOG(WARNING) << "GetOptInfo failed.";
return; return;

View File

@ -19,6 +19,7 @@
#include <string> #include <string>
#include "backend/common/optimizer/common_backend_optimization.h" #include "backend/common/optimizer/common_backend_optimization.h"
#include "plugin/device/ascend/optimizer/ascend_backend_optimization.h" #include "plugin/device/ascend/optimizer/ascend_backend_optimization.h"
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
#include "common/graph_kernel/adapter/graph_kernel_optimization.h" #include "common/graph_kernel/adapter/graph_kernel_optimization.h"
#include "common/graph_kernel/adapter/expander.h" #include "common/graph_kernel/adapter/expander.h"
#include "common/graph_kernel/value_graph_binder.h" #include "common/graph_kernel/value_graph_binder.h"
@ -113,6 +114,8 @@ void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPt
memo_.clear(); memo_.clear();
AddGraphToManager(NOT_NULL(graph), NOT_NULL(graph_manager_)); AddGraphToManager(NOT_NULL(graph), NOT_NULL(graph_manager_));
PlatformInfoInitialization();
memo_.clear(); memo_.clear();
IRFusionOptimization(graph); IRFusionOptimization(graph);
} }

View File

@ -22,11 +22,15 @@
#include "backend/common/session/anf_runtime_algorithm.h" #include "backend/common/session/anf_runtime_algorithm.h"
#include "runtime/device/ms_device_shape_transfer.h" #include "runtime/device/ms_device_shape_transfer.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/dev.h"
#include "common/util/platform_info.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
constexpr auto kUnknowErrorString = "Unknown error occurred"; constexpr auto kUnknowErrorString = "Unknown error occurred";
constexpr auto kSOC_VERSION = "SOC_VERSION";
void ReportErrorMessage() { void ReportErrorMessage() {
const string &error_message = ErrorManager::GetInstance().GetErrorMessage(); const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) { if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
@ -56,6 +60,56 @@ bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph) {
[](const AnfNodePtr &node) { return common::AnfAlgo::IsDynamicShape(node); }); [](const AnfNodePtr &node) { return common::AnfAlgo::IsDynamicShape(node); });
} }
std::string GetSocVersion() {
// Get default soc version.
static std::string version;
if (version.empty()) {
const int kSocVersionLen = 50;
char soc_version[kSocVersionLen] = {0};
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "GetSocVersion failed.";
}
// Get soc version from env value.
const char *soc_version_env = nullptr;
std::string str_soc_version_env = common::GetEnv(kSOC_VERSION);
if (!str_soc_version_env.empty()) {
soc_version_env = common::SafeCStr(str_soc_version_env);
}
if (soc_version_env != nullptr) {
if (std::strcmp(soc_version, soc_version_env) != 0) {
MS_LOG(DEBUG) << "Detected the env SOC_VERSION, so the SocVersion will be changed to " << str_soc_version_env
<< ".";
ret = rtSetSocVersion(soc_version_env);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "SetSocVersion failed, errorno: " << ret;
}
version = soc_version_env;
return soc_version_env;
}
}
version = soc_version;
}
return version;
}
void PlatformInfoInitialization() {
auto soc_version = GetSocVersion();
fe::PlatformInfo platform_info;
fe::OptionalInfo opti_compilation_info;
fe::PlatformInfoManager &inst = fe::PlatformInfoManager::Instance();
if (inst.InitializePlatformInfo() != 0) {
MS_LOG(WARNING) << "Initialize PlatformInfo failed.";
return;
}
if (inst.GetPlatformInfo(soc_version, platform_info, opti_compilation_info) != 0) {
MS_LOG(WARNING) << "GetPlatformInfo failed.";
return;
}
opti_compilation_info.soc_version = soc_version;
inst.SetOptionalCompilationInfo(opti_compilation_info);
}
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context) { void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_ #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_ #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
#include <string>
#include "plugin/device/ascend/hal/hardware/ascend_device_context.h" #include "plugin/device/ascend/hal/hardware/ascend_device_context.h"
#include "backend/common/session/kernel_graph.h" #include "backend/common/session/kernel_graph.h"
@ -30,6 +31,9 @@ void SetErrorManagerContext();
bool IsGraphMode(); bool IsGraphMode();
bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph); bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph);
std::string GetSocVersion();
void PlatformInfoInitialization();
// Some NOP nodes have be hide in execution order, it doesn't have output device address, this function creates // Some NOP nodes have be hide in execution order, it doesn't have output device address, this function creates
// output device address for these nodes, and the output device address is the same with input device address. // output device address for these nodes, and the output device address is the same with input device address.
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context); void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context);

View File

@ -41,6 +41,7 @@
#include "mindspore/ccsrc/include/common/debug/common.h" #include "mindspore/ccsrc/include/common/debug/common.h"
#include "kernel/common_utils.h" #include "kernel/common_utils.h"
#include "mindspore/core/utils/file_utils.h" #include "mindspore/core/utils/file_utils.h"
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -48,7 +49,6 @@ namespace tbe {
constexpr auto kCceKernelMeta = "kernel_meta/"; constexpr auto kCceKernelMeta = "kernel_meta/";
constexpr auto kJsonSuffix = ".json"; constexpr auto kJsonSuffix = ".json";
constexpr auto kInfoSuffix = ".info"; constexpr auto kInfoSuffix = ".info";
constexpr auto kSOC_VERSION = "SOC_VERSION";
constexpr auto kBuildRes = "build_result"; constexpr auto kBuildRes = "build_result";
constexpr auto kTUNE_BANK_PATH = "TUNE_BANK_PATH"; constexpr auto kTUNE_BANK_PATH = "TUNE_BANK_PATH";
constexpr auto kTUNE_DUMP_PATH = "TUNE_DUMP_PATH"; constexpr auto kTUNE_DUMP_PATH = "TUNE_DUMP_PATH";
@ -148,7 +148,7 @@ nlohmann::json TbeUtils::GenSocInfo() {
soc_info_json["l1Fusion"] = "false"; soc_info_json["l1Fusion"] = "false";
soc_info_json["l2Fusion"] = "false"; soc_info_json["l2Fusion"] = "false";
soc_info_json["op_bank_update"] = false; soc_info_json["op_bank_update"] = false;
soc_info_json["socVersion"] = GetSocVersion(); soc_info_json["socVersion"] = device::ascend::GetSocVersion();
soc_info_json["offlineTune"] = CheckOfflineTune(); soc_info_json["offlineTune"] = CheckOfflineTune();
soc_info_json["op_debug_dir"] = GetOpDebugPath(); soc_info_json["op_debug_dir"] = GetOpDebugPath();
soc_info_json["op_debug_level"] = GetOpDebugLevel(); soc_info_json["op_debug_level"] = GetOpDebugLevel();
@ -472,39 +472,6 @@ bool TbeUtils::CheckOfflineTune() {
return offline; return offline;
} }
std::string TbeUtils::GetSocVersion() {
// Get default soc version.
static std::string version;
if (version.empty()) {
const int kSocVersionLen = 50;
char soc_version[kSocVersionLen] = {0};
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "GetSocVersion failed.";
}
// Get soc version from env value.
const char *soc_version_env = nullptr;
std::string str_soc_version_env = common::GetEnv(kSOC_VERSION);
if (!str_soc_version_env.empty()) {
soc_version_env = common::SafeCStr(str_soc_version_env);
}
if (soc_version_env != nullptr) {
if (std::strcmp(soc_version, soc_version_env) != 0) {
MS_LOG(DEBUG) << "Detected the env SOC_VERSION, so the SocVersion will be changed to " << str_soc_version_env
<< ".";
ret = rtSetSocVersion(soc_version_env);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "SetSocVersion failed, errorno: " << ret;
}
version = soc_version_env;
return soc_version_env;
}
}
version = soc_version;
}
return version;
}
KernelPackPtr KernelMeta::LoadFromFile(const std::string &kernel_name) const { KernelPackPtr KernelMeta::LoadFromFile(const std::string &kernel_name) const {
auto config_path = TbeUtils::GetOpDebugPath(); auto config_path = TbeUtils::GetOpDebugPath();
std::string cce_json = config_path + kCceKernelMeta + kernel_name + kJsonSuffix; std::string cce_json = config_path + kCceKernelMeta + kernel_name + kJsonSuffix;

View File

@ -56,8 +56,6 @@ class TbeUtils {
static nlohmann::json GenSocInfo(); static nlohmann::json GenSocInfo();
static std::string GetSocVersion();
static std::string GetOpDebugPath(); static std::string GetOpDebugPath();
static std::string GetBankPath(); static std::string GetBankPath();

View File

@ -29,60 +29,89 @@
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "plugin/device/ascend/optimizer/optimizer_factory.h" #include "plugin/device/ascend/optimizer/optimizer_factory.h"
#include "common/util/platform_info.h"
namespace mindspore::opt { namespace mindspore::opt {
namespace { namespace {
constexpr size_t kFloat16Len = 2; // size of float16; constexpr size_t kMultiply2 = 2;
constexpr size_t kTopkIndexK = 1; constexpr size_t kTopkIndexK = 1;
constexpr auto kAttrSorted = "sorted"; constexpr auto kAttrSorted = "sorted";
tensor::TensorPtr CreateTensor() { tensor::TensorPtr ConstructAssistTensor(size_t assist_len, bool is_segment_sort = false, bool is_int32 = false) {
// 1 create tensor // create tensor
const size_t last_dim = 4096; int64_t shape_len = is_segment_sort ? SizeToLong(assist_len) : SizeToLong(assist_len * kMultiply2);
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)}; std::vector<int64_t> assist_shape{shape_len};
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16); auto dtype = is_int32 ? kInt32 : kFloat16;
MS_EXCEPTION_IF_NULL(tensor_type); TensorTypePtr tensor_type = std::make_shared<TensorType>(dtype);
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type}; tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
tensor::TensorPtr indices_tensor = std::make_shared<tensor::Tensor>(kFloat16->type_id(), indices_shape); tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), assist_shape);
MS_EXCEPTION_IF_NULL(indices_tensor); assist_tensor->set_device_info(device_info);
indices_tensor->set_device_info(device_info);
// 2 set value of tensor // set value of tensor
auto data_ptr = indices_tensor->data_c(); auto data_ptr = assist_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr); MS_EXCEPTION_IF_NULL(data_ptr);
std::vector<float16> half_data; if (is_int32) {
for (size_t i = 0; i < last_dim; ++i) { auto data = static_cast<int32_t *>(data_ptr);
(void)half_data.emplace_back(float16(static_cast<float>(i))); for (int32_t i = 0; i < SizeToInt(assist_len); ++i) {
*data = i;
++data;
}
} else {
auto data = static_cast<float16 *>(data_ptr);
for (size_t i = 0; i < assist_len; ++i) {
*data = float16(static_cast<float>(i));
++data;
}
if (!is_segment_sort) {
for (size_t i = 0; i < assist_len; ++i) {
auto gap = static_cast<int>(i) - static_cast<int>(float16(static_cast<float>(i)));
*data = float16(static_cast<float>(gap));
++data;
}
}
} }
for (size_t i = 0; i < last_dim; ++i) {
auto gap = static_cast<int>(i) - static_cast<int>(float16(static_cast<float>(i))); return assist_tensor;
(void)half_data.emplace_back(float16(static_cast<float>(gap)));
}
auto elem_num = last_dim * kFloat16Len * 2;
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(indices_tensor->data().nbytes()),
static_cast<void *>(half_data.data()), elem_num);
if (ret_code != 0) {
MS_LOG(ERROR) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
return nullptr;
}
return indices_tensor;
} }
ValueNodePtr CreateValueNode() { tensor::TensorPtr CreateAssistTensor(const std::vector<int64_t> &input_shape, int32_t k_num,
tensor::TensorPtr indices_tensor = CreateTensor(); const fe::PlatformInfo &platform_info, const fe::OptionalInfo &optional_info) {
MS_EXCEPTION_IF_NULL(indices_tensor); bool is_lhisi = optional_info.soc_version.find("Hi3796CV300CS") != std::string::npos ||
auto indices_const = std::make_shared<ValueNode>(indices_tensor); optional_info.soc_version.find("Hi3796CV300ES") != std::string::npos ||
MS_EXCEPTION_IF_NULL(indices_const); optional_info.soc_version.find("SD3403") != std::string::npos;
auto indices_abstract = indices_tensor->ToAbstract(); constexpr int64_t kLhisiMaxLastSize = 3000;
indices_const->set_abstract(indices_abstract); constexpr int64_t kHisiMaxLastSize = 5000;
auto indices_kernel_info = std::make_shared<device::KernelInfo>(); constexpr int64_t kLhisiMaxKNum = 2048;
MS_EXCEPTION_IF_NULL(indices_kernel_info); constexpr int64_t kHisiMaxKNum = 4096;
indices_const->set_kernel_info(indices_kernel_info); constexpr size_t kSmallSceneAssistLen = 4096;
constexpr size_t kLargeSceneAssistLen = 2048;
int64_t max_last_size = is_lhisi ? kLhisiMaxLastSize : kHisiMaxLastSize;
int64_t max_k_num = is_lhisi ? kLhisiMaxKNum : kHisiMaxKNum;
if (input_shape.back() > max_last_size || k_num > max_k_num) {
if (platform_info.str_info.short_soc_version == "Ascend910B" ||
platform_info.str_info.short_soc_version == "Ascend310B") {
return ConstructAssistTensor(kLargeSceneAssistLen, true, true);
} else {
return ConstructAssistTensor(kLargeSceneAssistLen, true);
}
}
return ConstructAssistTensor(kSmallSceneAssistLen);
}
ValueNodePtr CreateAssistNode(const std::vector<int64_t> &input_shape, int32_t k_num,
const fe::PlatformInfo &platform_info, const fe::OptionalInfo &optional_info) {
tensor::TensorPtr assist_tensor = CreateAssistTensor(input_shape, k_num, platform_info, optional_info);
MS_EXCEPTION_IF_NULL(assist_tensor);
auto assist_const = std::make_shared<ValueNode>(assist_tensor);
auto assist_abstract = assist_tensor->ToAbstract();
assist_const->set_abstract(assist_abstract);
auto assist_kernel_info = std::make_shared<device::KernelInfo>();
assist_const->set_kernel_info(assist_kernel_info);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
builder1.SetOutputsFormat({kOpFormat_DEFAULT}); builder1.SetOutputsFormat({kOpFormat_DEFAULT});
builder1.SetOutputsDeviceType({kNumberTypeFloat16}); builder1.SetOutputsDeviceType({common::AnfAlgo::GetOutputInferDataType(assist_const, 0)});
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), assist_const.get());
return indices_const; return assist_const;
} }
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
@ -106,13 +135,13 @@ bool CheckInputNamesSize(const CNodePtr &cnode) {
return true; return true;
} }
bool CheckOutputShape(const AnfNodePtr &node) { bool CheckInputShape(const AnfNodePtr &node) {
auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0); auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
if (shape.empty()) { if (shape.empty()) {
MS_LOG(INFO) << "The output shape of topk to split must not be empty"; MS_LOG(INFO) << "The input shape of topk to split must not be empty";
return false; return false;
} }
auto last_dim = shape[shape.size() - 1]; auto last_dim = shape.back();
const int64_t kMaxFloat16 = 65500; const int64_t kMaxFloat16 = 65500;
if (last_dim > kMaxFloat16) { if (last_dim > kMaxFloat16) {
MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops."; MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
@ -132,13 +161,13 @@ bool CheckInputType(const AnfNodePtr &node) {
} }
bool CheckFusion(const CNodePtr &node) { bool CheckFusion(const CNodePtr &node) {
if (!common::AnfAlgo::HasNodeAttr(kAttrSorted, node) || !common::AnfAlgo::GetNodeAttr<bool>(node, kAttrSorted)) { if (common::AnfAlgo::HasNodeAttr(kAttrSorted, node) && !common::AnfAlgo::GetNodeAttr<bool>(node, kAttrSorted)) {
return false; return false;
} }
if (!CheckInputNamesSize(node)) { if (!CheckInputNamesSize(node)) {
return false; return false;
} }
if (!CheckOutputShape(node)) { if (!CheckInputShape(node)) {
return false; return false;
} }
if (!CheckInputType(node)) { if (!CheckInputType(node)) {
@ -182,28 +211,37 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
if (!IsValueNode<tensor::Tensor>(input_k)) { if (!IsValueNode<tensor::Tensor>(input_k)) {
return nullptr; return nullptr;
} }
fe::PlatformInfo platform_info;
fe::OptionalInfo optional_info;
if (fe::PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platform_info, optional_info)) {
MS_LOG(WARNING) << "Get platform info failed, quit fusion.";
return nullptr;
}
ValuePtr value = GetValueNode(input_k); ValuePtr value = GetValueNode(input_k);
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
auto tensor = value->cast<tensor::TensorPtr>(); auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
auto *data = static_cast<int32_t *>(tensor->data_c()); auto *data = static_cast<int32_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data)); int32_t k_num = *data;
auto new_value_node = std::make_shared<ValueNode>(MakeValue(k_num));
new_cnode->set_input(kTopkIndexK + 1, new_value_node); new_cnode->set_input(kTopkIndexK + 1, new_value_node);
mindspore::HashSet<size_t> attr_index{kTopkIndexK}; mindspore::HashSet<size_t> attr_index{kTopkIndexK};
ConstInputToAttr(new_cnode, attr_index); ConstInputToAttr(new_cnode, attr_index);
auto indices_const = CreateValueNode(); auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(new_cnode, 0);
new_cnode->add_input(indices_const); auto assist_const = CreateAssistNode(input_shape, k_num, platform_info, optional_info);
new_cnode->add_input(assist_const);
MS_EXCEPTION_IF_NULL(supported_checker_); MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
MS_LOG(INFO) << "split topk failed, check to aicpu."; MS_LOG(INFO) << "Split topk failed, check to aicpu.";
return nullptr; return nullptr;
} }
if (kernel_graph != nullptr) { if (kernel_graph != nullptr) {
MS_LOG(INFO) << "split topk success. use tbe aicore."; MS_LOG(INFO) << "Split topk success. use tbe aicore.";
kernel_graph->AddValueNodeToGraph(indices_const); kernel_graph->AddValueNodeToGraph(assist_const);
} }
return new_cnode; return new_cnode;

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/util/platform_info.h"
namespace fe {
PlatformInfoManager &PlatformInfoManager::Instance() {
static PlatformInfoManager instance{};
return instance;
}
uint32_t PlatformInfoManager::InitializePlatformInfo() { return 0; }
uint32_t PlatformInfoManager::Finalize() { return 0; }
uint32_t PlatformInfoManager::GetPlatformInfo(const std::string soc_version, PlatformInfo &platform_info,
OptionalInfo &optional_info) {
return 0;
}
uint32_t PlatformInfoManager::GetPlatformInfoWithOutSocVersion(PlatformInfo &platform_info,
OptionalInfo &optional_info) {
return 0;
}
void PlatformInfoManager::SetOptionalCompilationInfo(OptionalInfo &optional_info) {}
uint32_t PlatformInfoManager::GetPlatformInfos(const std::string soc_version, PlatFormInfos &platform_info,
OptionalInfos &optional_info) {
return 0;
}
uint32_t PlatformInfoManager::GetPlatformInfoWithOutSocVersion(PlatFormInfos &platform_infos,
OptionalInfos &optional_infos) {
return 0;
}
void PlatformInfoManager::SetOptionalCompilationInfo(OptionalInfos &optional_infos) {}
PlatformInfoManager::PlatformInfoManager() {}
PlatformInfoManager::~PlatformInfoManager() {}
} // namespace fe