Update TopKSplit pass
This commit is contained in:
parent
c218d871be
commit
8029bbf815
|
@ -18,6 +18,7 @@
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "runtime/dev.h"
|
#include "runtime/dev.h"
|
||||||
#include "opt_info/opt_info.h"
|
#include "opt_info/opt_info.h"
|
||||||
|
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -55,18 +56,6 @@ inline std::vector<std::string> SplitStrByRegex(const std::string &str, const st
|
||||||
return std::vector<std::string>(std::sregex_token_iterator(str.begin(), str.end(), split, -1),
|
return std::vector<std::string>(std::sregex_token_iterator(str.begin(), str.end(), split, -1),
|
||||||
std::sregex_token_iterator());
|
std::sregex_token_iterator());
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string GetSocVersion() {
|
|
||||||
constexpr int kSocVersionLen = 50;
|
|
||||||
char soc_version[kSocVersionLen] = {0};
|
|
||||||
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
|
|
||||||
if (ret != RT_ERROR_NONE) {
|
|
||||||
MS_LOG(WARNING) << "rtGetSocVersion failed, ret = " << ret;
|
|
||||||
return "Ascend910";
|
|
||||||
}
|
|
||||||
|
|
||||||
return soc_version;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LicManager::LicManager() { ParseSwitch(); }
|
LicManager::LicManager() { ParseSwitch(); }
|
||||||
|
@ -87,7 +76,7 @@ bool LicManager::GetPassSwitch(OptPassEnum pass) const {
|
||||||
|
|
||||||
void LicManager::ParseSwitch() {
|
void LicManager::ParseSwitch() {
|
||||||
std::map<std::string, std::string> opt_info_map;
|
std::map<std::string, std::string> opt_info_map;
|
||||||
auto ret = gelc::GetOptInfo(0, GetSocVersion(), opt_info_map);
|
auto ret = gelc::GetOptInfo(0, device::ascend::GetSocVersion(), opt_info_map);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
MS_LOG(WARNING) << "GetOptInfo failed.";
|
MS_LOG(WARNING) << "GetOptInfo failed.";
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/common_backend_optimization.h"
|
#include "backend/common/optimizer/common_backend_optimization.h"
|
||||||
#include "plugin/device/ascend/optimizer/ascend_backend_optimization.h"
|
#include "plugin/device/ascend/optimizer/ascend_backend_optimization.h"
|
||||||
|
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||||
#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
|
#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
|
||||||
#include "common/graph_kernel/adapter/expander.h"
|
#include "common/graph_kernel/adapter/expander.h"
|
||||||
#include "common/graph_kernel/value_graph_binder.h"
|
#include "common/graph_kernel/value_graph_binder.h"
|
||||||
|
@ -113,6 +114,8 @@ void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPt
|
||||||
memo_.clear();
|
memo_.clear();
|
||||||
AddGraphToManager(NOT_NULL(graph), NOT_NULL(graph_manager_));
|
AddGraphToManager(NOT_NULL(graph), NOT_NULL(graph_manager_));
|
||||||
|
|
||||||
|
PlatformInfoInitialization();
|
||||||
|
|
||||||
memo_.clear();
|
memo_.clear();
|
||||||
IRFusionOptimization(graph);
|
IRFusionOptimization(graph);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,11 +22,15 @@
|
||||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||||
#include "runtime/device/ms_device_shape_transfer.h"
|
#include "runtime/device/ms_device_shape_transfer.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
|
#include "runtime/dev.h"
|
||||||
|
#include "common/util/platform_info.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace ascend {
|
namespace ascend {
|
||||||
constexpr auto kUnknowErrorString = "Unknown error occurred";
|
constexpr auto kUnknowErrorString = "Unknown error occurred";
|
||||||
|
constexpr auto kSOC_VERSION = "SOC_VERSION";
|
||||||
|
|
||||||
void ReportErrorMessage() {
|
void ReportErrorMessage() {
|
||||||
const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
|
||||||
if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
|
if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
|
||||||
|
@ -56,6 +60,56 @@ bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph) {
|
||||||
[](const AnfNodePtr &node) { return common::AnfAlgo::IsDynamicShape(node); });
|
[](const AnfNodePtr &node) { return common::AnfAlgo::IsDynamicShape(node); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GetSocVersion() {
|
||||||
|
// Get default soc version.
|
||||||
|
static std::string version;
|
||||||
|
if (version.empty()) {
|
||||||
|
const int kSocVersionLen = 50;
|
||||||
|
char soc_version[kSocVersionLen] = {0};
|
||||||
|
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(EXCEPTION) << "GetSocVersion failed.";
|
||||||
|
}
|
||||||
|
// Get soc version from env value.
|
||||||
|
const char *soc_version_env = nullptr;
|
||||||
|
std::string str_soc_version_env = common::GetEnv(kSOC_VERSION);
|
||||||
|
if (!str_soc_version_env.empty()) {
|
||||||
|
soc_version_env = common::SafeCStr(str_soc_version_env);
|
||||||
|
}
|
||||||
|
if (soc_version_env != nullptr) {
|
||||||
|
if (std::strcmp(soc_version, soc_version_env) != 0) {
|
||||||
|
MS_LOG(DEBUG) << "Detected the env SOC_VERSION, so the SocVersion will be changed to " << str_soc_version_env
|
||||||
|
<< ".";
|
||||||
|
ret = rtSetSocVersion(soc_version_env);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(EXCEPTION) << "SetSocVersion failed, errorno: " << ret;
|
||||||
|
}
|
||||||
|
version = soc_version_env;
|
||||||
|
return soc_version_env;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
version = soc_version;
|
||||||
|
}
|
||||||
|
return version;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PlatformInfoInitialization() {
|
||||||
|
auto soc_version = GetSocVersion();
|
||||||
|
fe::PlatformInfo platform_info;
|
||||||
|
fe::OptionalInfo opti_compilation_info;
|
||||||
|
fe::PlatformInfoManager &inst = fe::PlatformInfoManager::Instance();
|
||||||
|
if (inst.InitializePlatformInfo() != 0) {
|
||||||
|
MS_LOG(WARNING) << "Initialize PlatformInfo failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (inst.GetPlatformInfo(soc_version, platform_info, opti_compilation_info) != 0) {
|
||||||
|
MS_LOG(WARNING) << "GetPlatformInfo failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
opti_compilation_info.soc_version = soc_version;
|
||||||
|
inst.SetOptionalCompilationInfo(opti_compilation_info);
|
||||||
|
}
|
||||||
|
|
||||||
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
|
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
|
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
|
||||||
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
|
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_ASCEND_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
#include "plugin/device/ascend/hal/hardware/ascend_device_context.h"
|
#include "plugin/device/ascend/hal/hardware/ascend_device_context.h"
|
||||||
#include "backend/common/session/kernel_graph.h"
|
#include "backend/common/session/kernel_graph.h"
|
||||||
|
|
||||||
|
@ -30,6 +31,9 @@ void SetErrorManagerContext();
|
||||||
bool IsGraphMode();
|
bool IsGraphMode();
|
||||||
bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph);
|
bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
|
std::string GetSocVersion();
|
||||||
|
void PlatformInfoInitialization();
|
||||||
|
|
||||||
// Some NOP nodes have be hide in execution order, it doesn't have output device address, this function creates
|
// Some NOP nodes have be hide in execution order, it doesn't have output device address, this function creates
|
||||||
// output device address for these nodes, and the output device address is the same with input device address.
|
// output device address for these nodes, and the output device address is the same with input device address.
|
||||||
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context);
|
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device::DeviceContext *device_context);
|
||||||
|
|
|
@ -41,6 +41,7 @@
|
||||||
#include "mindspore/ccsrc/include/common/debug/common.h"
|
#include "mindspore/ccsrc/include/common/debug/common.h"
|
||||||
#include "kernel/common_utils.h"
|
#include "kernel/common_utils.h"
|
||||||
#include "mindspore/core/utils/file_utils.h"
|
#include "mindspore/core/utils/file_utils.h"
|
||||||
|
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -48,7 +49,6 @@ namespace tbe {
|
||||||
constexpr auto kCceKernelMeta = "kernel_meta/";
|
constexpr auto kCceKernelMeta = "kernel_meta/";
|
||||||
constexpr auto kJsonSuffix = ".json";
|
constexpr auto kJsonSuffix = ".json";
|
||||||
constexpr auto kInfoSuffix = ".info";
|
constexpr auto kInfoSuffix = ".info";
|
||||||
constexpr auto kSOC_VERSION = "SOC_VERSION";
|
|
||||||
constexpr auto kBuildRes = "build_result";
|
constexpr auto kBuildRes = "build_result";
|
||||||
constexpr auto kTUNE_BANK_PATH = "TUNE_BANK_PATH";
|
constexpr auto kTUNE_BANK_PATH = "TUNE_BANK_PATH";
|
||||||
constexpr auto kTUNE_DUMP_PATH = "TUNE_DUMP_PATH";
|
constexpr auto kTUNE_DUMP_PATH = "TUNE_DUMP_PATH";
|
||||||
|
@ -148,7 +148,7 @@ nlohmann::json TbeUtils::GenSocInfo() {
|
||||||
soc_info_json["l1Fusion"] = "false";
|
soc_info_json["l1Fusion"] = "false";
|
||||||
soc_info_json["l2Fusion"] = "false";
|
soc_info_json["l2Fusion"] = "false";
|
||||||
soc_info_json["op_bank_update"] = false;
|
soc_info_json["op_bank_update"] = false;
|
||||||
soc_info_json["socVersion"] = GetSocVersion();
|
soc_info_json["socVersion"] = device::ascend::GetSocVersion();
|
||||||
soc_info_json["offlineTune"] = CheckOfflineTune();
|
soc_info_json["offlineTune"] = CheckOfflineTune();
|
||||||
soc_info_json["op_debug_dir"] = GetOpDebugPath();
|
soc_info_json["op_debug_dir"] = GetOpDebugPath();
|
||||||
soc_info_json["op_debug_level"] = GetOpDebugLevel();
|
soc_info_json["op_debug_level"] = GetOpDebugLevel();
|
||||||
|
@ -472,39 +472,6 @@ bool TbeUtils::CheckOfflineTune() {
|
||||||
return offline;
|
return offline;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string TbeUtils::GetSocVersion() {
|
|
||||||
// Get default soc version.
|
|
||||||
static std::string version;
|
|
||||||
if (version.empty()) {
|
|
||||||
const int kSocVersionLen = 50;
|
|
||||||
char soc_version[kSocVersionLen] = {0};
|
|
||||||
auto ret = rtGetSocVersion(soc_version, kSocVersionLen);
|
|
||||||
if (ret != RT_ERROR_NONE) {
|
|
||||||
MS_LOG(EXCEPTION) << "GetSocVersion failed.";
|
|
||||||
}
|
|
||||||
// Get soc version from env value.
|
|
||||||
const char *soc_version_env = nullptr;
|
|
||||||
std::string str_soc_version_env = common::GetEnv(kSOC_VERSION);
|
|
||||||
if (!str_soc_version_env.empty()) {
|
|
||||||
soc_version_env = common::SafeCStr(str_soc_version_env);
|
|
||||||
}
|
|
||||||
if (soc_version_env != nullptr) {
|
|
||||||
if (std::strcmp(soc_version, soc_version_env) != 0) {
|
|
||||||
MS_LOG(DEBUG) << "Detected the env SOC_VERSION, so the SocVersion will be changed to " << str_soc_version_env
|
|
||||||
<< ".";
|
|
||||||
ret = rtSetSocVersion(soc_version_env);
|
|
||||||
if (ret != RT_ERROR_NONE) {
|
|
||||||
MS_LOG(EXCEPTION) << "SetSocVersion failed, errorno: " << ret;
|
|
||||||
}
|
|
||||||
version = soc_version_env;
|
|
||||||
return soc_version_env;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
version = soc_version;
|
|
||||||
}
|
|
||||||
return version;
|
|
||||||
}
|
|
||||||
|
|
||||||
KernelPackPtr KernelMeta::LoadFromFile(const std::string &kernel_name) const {
|
KernelPackPtr KernelMeta::LoadFromFile(const std::string &kernel_name) const {
|
||||||
auto config_path = TbeUtils::GetOpDebugPath();
|
auto config_path = TbeUtils::GetOpDebugPath();
|
||||||
std::string cce_json = config_path + kCceKernelMeta + kernel_name + kJsonSuffix;
|
std::string cce_json = config_path + kCceKernelMeta + kernel_name + kJsonSuffix;
|
||||||
|
|
|
@ -56,8 +56,6 @@ class TbeUtils {
|
||||||
|
|
||||||
static nlohmann::json GenSocInfo();
|
static nlohmann::json GenSocInfo();
|
||||||
|
|
||||||
static std::string GetSocVersion();
|
|
||||||
|
|
||||||
static std::string GetOpDebugPath();
|
static std::string GetOpDebugPath();
|
||||||
|
|
||||||
static std::string GetBankPath();
|
static std::string GetBankPath();
|
||||||
|
|
|
@ -29,60 +29,89 @@
|
||||||
#include "runtime/device/kernel_info.h"
|
#include "runtime/device/kernel_info.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "plugin/device/ascend/optimizer/optimizer_factory.h"
|
#include "plugin/device/ascend/optimizer/optimizer_factory.h"
|
||||||
|
#include "common/util/platform_info.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kFloat16Len = 2; // size of float16;
|
constexpr size_t kMultiply2 = 2;
|
||||||
constexpr size_t kTopkIndexK = 1;
|
constexpr size_t kTopkIndexK = 1;
|
||||||
constexpr auto kAttrSorted = "sorted";
|
constexpr auto kAttrSorted = "sorted";
|
||||||
|
|
||||||
tensor::TensorPtr CreateTensor() {
|
tensor::TensorPtr ConstructAssistTensor(size_t assist_len, bool is_segment_sort = false, bool is_int32 = false) {
|
||||||
// 1 create tensor
|
// create tensor
|
||||||
const size_t last_dim = 4096;
|
int64_t shape_len = is_segment_sort ? SizeToLong(assist_len) : SizeToLong(assist_len * kMultiply2);
|
||||||
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)};
|
std::vector<int64_t> assist_shape{shape_len};
|
||||||
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
|
auto dtype = is_int32 ? kInt32 : kFloat16;
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
TensorTypePtr tensor_type = std::make_shared<TensorType>(dtype);
|
||||||
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
|
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
|
||||||
tensor::TensorPtr indices_tensor = std::make_shared<tensor::Tensor>(kFloat16->type_id(), indices_shape);
|
tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), assist_shape);
|
||||||
MS_EXCEPTION_IF_NULL(indices_tensor);
|
assist_tensor->set_device_info(device_info);
|
||||||
indices_tensor->set_device_info(device_info);
|
|
||||||
|
|
||||||
// 2 set value of tensor
|
// set value of tensor
|
||||||
auto data_ptr = indices_tensor->data_c();
|
auto data_ptr = assist_tensor->data_c();
|
||||||
MS_EXCEPTION_IF_NULL(data_ptr);
|
MS_EXCEPTION_IF_NULL(data_ptr);
|
||||||
std::vector<float16> half_data;
|
if (is_int32) {
|
||||||
for (size_t i = 0; i < last_dim; ++i) {
|
auto data = static_cast<int32_t *>(data_ptr);
|
||||||
(void)half_data.emplace_back(float16(static_cast<float>(i)));
|
for (int32_t i = 0; i < SizeToInt(assist_len); ++i) {
|
||||||
|
*data = i;
|
||||||
|
++data;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto data = static_cast<float16 *>(data_ptr);
|
||||||
|
for (size_t i = 0; i < assist_len; ++i) {
|
||||||
|
*data = float16(static_cast<float>(i));
|
||||||
|
++data;
|
||||||
|
}
|
||||||
|
if (!is_segment_sort) {
|
||||||
|
for (size_t i = 0; i < assist_len; ++i) {
|
||||||
|
auto gap = static_cast<int>(i) - static_cast<int>(float16(static_cast<float>(i)));
|
||||||
|
*data = float16(static_cast<float>(gap));
|
||||||
|
++data;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < last_dim; ++i) {
|
|
||||||
auto gap = static_cast<int>(i) - static_cast<int>(float16(static_cast<float>(i)));
|
return assist_tensor;
|
||||||
(void)half_data.emplace_back(float16(static_cast<float>(gap)));
|
|
||||||
}
|
|
||||||
auto elem_num = last_dim * kFloat16Len * 2;
|
|
||||||
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(indices_tensor->data().nbytes()),
|
|
||||||
static_cast<void *>(half_data.data()), elem_num);
|
|
||||||
if (ret_code != 0) {
|
|
||||||
MS_LOG(ERROR) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return indices_tensor;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueNodePtr CreateValueNode() {
|
tensor::TensorPtr CreateAssistTensor(const std::vector<int64_t> &input_shape, int32_t k_num,
|
||||||
tensor::TensorPtr indices_tensor = CreateTensor();
|
const fe::PlatformInfo &platform_info, const fe::OptionalInfo &optional_info) {
|
||||||
MS_EXCEPTION_IF_NULL(indices_tensor);
|
bool is_lhisi = optional_info.soc_version.find("Hi3796CV300CS") != std::string::npos ||
|
||||||
auto indices_const = std::make_shared<ValueNode>(indices_tensor);
|
optional_info.soc_version.find("Hi3796CV300ES") != std::string::npos ||
|
||||||
MS_EXCEPTION_IF_NULL(indices_const);
|
optional_info.soc_version.find("SD3403") != std::string::npos;
|
||||||
auto indices_abstract = indices_tensor->ToAbstract();
|
constexpr int64_t kLhisiMaxLastSize = 3000;
|
||||||
indices_const->set_abstract(indices_abstract);
|
constexpr int64_t kHisiMaxLastSize = 5000;
|
||||||
auto indices_kernel_info = std::make_shared<device::KernelInfo>();
|
constexpr int64_t kLhisiMaxKNum = 2048;
|
||||||
MS_EXCEPTION_IF_NULL(indices_kernel_info);
|
constexpr int64_t kHisiMaxKNum = 4096;
|
||||||
indices_const->set_kernel_info(indices_kernel_info);
|
constexpr size_t kSmallSceneAssistLen = 4096;
|
||||||
|
constexpr size_t kLargeSceneAssistLen = 2048;
|
||||||
|
int64_t max_last_size = is_lhisi ? kLhisiMaxLastSize : kHisiMaxLastSize;
|
||||||
|
int64_t max_k_num = is_lhisi ? kLhisiMaxKNum : kHisiMaxKNum;
|
||||||
|
if (input_shape.back() > max_last_size || k_num > max_k_num) {
|
||||||
|
if (platform_info.str_info.short_soc_version == "Ascend910B" ||
|
||||||
|
platform_info.str_info.short_soc_version == "Ascend310B") {
|
||||||
|
return ConstructAssistTensor(kLargeSceneAssistLen, true, true);
|
||||||
|
} else {
|
||||||
|
return ConstructAssistTensor(kLargeSceneAssistLen, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ConstructAssistTensor(kSmallSceneAssistLen);
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueNodePtr CreateAssistNode(const std::vector<int64_t> &input_shape, int32_t k_num,
|
||||||
|
const fe::PlatformInfo &platform_info, const fe::OptionalInfo &optional_info) {
|
||||||
|
tensor::TensorPtr assist_tensor = CreateAssistTensor(input_shape, k_num, platform_info, optional_info);
|
||||||
|
MS_EXCEPTION_IF_NULL(assist_tensor);
|
||||||
|
auto assist_const = std::make_shared<ValueNode>(assist_tensor);
|
||||||
|
auto assist_abstract = assist_tensor->ToAbstract();
|
||||||
|
assist_const->set_abstract(assist_abstract);
|
||||||
|
auto assist_kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
assist_const->set_kernel_info(assist_kernel_info);
|
||||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
||||||
builder1.SetOutputsFormat({kOpFormat_DEFAULT});
|
builder1.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||||
builder1.SetOutputsDeviceType({kNumberTypeFloat16});
|
builder1.SetOutputsDeviceType({common::AnfAlgo::GetOutputInferDataType(assist_const, 0)});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), assist_const.get());
|
||||||
return indices_const;
|
return assist_const;
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
|
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
|
||||||
|
@ -106,13 +135,13 @@ bool CheckInputNamesSize(const CNodePtr &cnode) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckOutputShape(const AnfNodePtr &node) {
|
bool CheckInputShape(const AnfNodePtr &node) {
|
||||||
auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||||
if (shape.empty()) {
|
if (shape.empty()) {
|
||||||
MS_LOG(INFO) << "The output shape of topk to split must not be empty";
|
MS_LOG(INFO) << "The input shape of topk to split must not be empty";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto last_dim = shape[shape.size() - 1];
|
auto last_dim = shape.back();
|
||||||
const int64_t kMaxFloat16 = 65500;
|
const int64_t kMaxFloat16 = 65500;
|
||||||
if (last_dim > kMaxFloat16) {
|
if (last_dim > kMaxFloat16) {
|
||||||
MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
|
MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
|
||||||
|
@ -132,13 +161,13 @@ bool CheckInputType(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckFusion(const CNodePtr &node) {
|
bool CheckFusion(const CNodePtr &node) {
|
||||||
if (!common::AnfAlgo::HasNodeAttr(kAttrSorted, node) || !common::AnfAlgo::GetNodeAttr<bool>(node, kAttrSorted)) {
|
if (common::AnfAlgo::HasNodeAttr(kAttrSorted, node) && !common::AnfAlgo::GetNodeAttr<bool>(node, kAttrSorted)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CheckInputNamesSize(node)) {
|
if (!CheckInputNamesSize(node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CheckOutputShape(node)) {
|
if (!CheckInputShape(node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CheckInputType(node)) {
|
if (!CheckInputType(node)) {
|
||||||
|
@ -182,28 +211,37 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
||||||
if (!IsValueNode<tensor::Tensor>(input_k)) {
|
if (!IsValueNode<tensor::Tensor>(input_k)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fe::PlatformInfo platform_info;
|
||||||
|
fe::OptionalInfo optional_info;
|
||||||
|
if (fe::PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platform_info, optional_info)) {
|
||||||
|
MS_LOG(WARNING) << "Get platform info failed, quit fusion.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
ValuePtr value = GetValueNode(input_k);
|
ValuePtr value = GetValueNode(input_k);
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
auto tensor = value->cast<tensor::TensorPtr>();
|
auto tensor = value->cast<tensor::TensorPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(tensor);
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
auto *data = static_cast<int32_t *>(tensor->data_c());
|
auto *data = static_cast<int32_t *>(tensor->data_c());
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
|
int32_t k_num = *data;
|
||||||
|
auto new_value_node = std::make_shared<ValueNode>(MakeValue(k_num));
|
||||||
new_cnode->set_input(kTopkIndexK + 1, new_value_node);
|
new_cnode->set_input(kTopkIndexK + 1, new_value_node);
|
||||||
|
|
||||||
mindspore::HashSet<size_t> attr_index{kTopkIndexK};
|
mindspore::HashSet<size_t> attr_index{kTopkIndexK};
|
||||||
ConstInputToAttr(new_cnode, attr_index);
|
ConstInputToAttr(new_cnode, attr_index);
|
||||||
auto indices_const = CreateValueNode();
|
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(new_cnode, 0);
|
||||||
new_cnode->add_input(indices_const);
|
auto assist_const = CreateAssistNode(input_shape, k_num, platform_info, optional_info);
|
||||||
|
new_cnode->add_input(assist_const);
|
||||||
MS_EXCEPTION_IF_NULL(supported_checker_);
|
MS_EXCEPTION_IF_NULL(supported_checker_);
|
||||||
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||||
MS_LOG(INFO) << "split topk failed, check to aicpu.";
|
MS_LOG(INFO) << "Split topk failed, check to aicpu.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kernel_graph != nullptr) {
|
if (kernel_graph != nullptr) {
|
||||||
MS_LOG(INFO) << "split topk success. use tbe aicore.";
|
MS_LOG(INFO) << "Split topk success. use tbe aicore.";
|
||||||
kernel_graph->AddValueNodeToGraph(indices_const);
|
kernel_graph->AddValueNodeToGraph(assist_const);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_cnode;
|
return new_cnode;
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/util/platform_info.h"
|
||||||
|
|
||||||
|
namespace fe {
|
||||||
|
PlatformInfoManager &PlatformInfoManager::Instance() {
|
||||||
|
static PlatformInfoManager instance{};
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t PlatformInfoManager::InitializePlatformInfo() { return 0; }
|
||||||
|
|
||||||
|
uint32_t PlatformInfoManager::Finalize() { return 0; }
|
||||||
|
|
||||||
|
uint32_t PlatformInfoManager::GetPlatformInfo(const std::string soc_version, PlatformInfo &platform_info,
|
||||||
|
OptionalInfo &optional_info) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t PlatformInfoManager::GetPlatformInfoWithOutSocVersion(PlatformInfo &platform_info,
|
||||||
|
OptionalInfo &optional_info) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PlatformInfoManager::SetOptionalCompilationInfo(OptionalInfo &optional_info) {}
|
||||||
|
|
||||||
|
uint32_t PlatformInfoManager::GetPlatformInfos(const std::string soc_version, PlatFormInfos &platform_info,
|
||||||
|
OptionalInfos &optional_info) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t PlatformInfoManager::GetPlatformInfoWithOutSocVersion(PlatFormInfos &platform_infos,
|
||||||
|
OptionalInfos &optional_infos) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PlatformInfoManager::SetOptionalCompilationInfo(OptionalInfos &optional_infos) {}
|
||||||
|
|
||||||
|
PlatformInfoManager::PlatformInfoManager() {}
|
||||||
|
|
||||||
|
PlatformInfoManager::~PlatformInfoManager() {}
|
||||||
|
} // namespace fe
|
Loading…
Reference in New Issue