!27610 Support Ascend710 device to run converter and benchmark.

Merge pull request !27610 from chenping/master
Author: i-robot, 2021-12-20 12:20:38 +00:00 (committed by Gitee)
commit 5503948658
20 changed files with 152 additions and 140 deletions
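
At the API level, the patch folds the Ascend310-specific names (Ascend310DeviceInfo, DT_ASCEND310, the kModelOptionAscend310* keys) into generic Ascend equivalents, so one code path serves both Ascend310 and Ascend710. A minimal caller-side sketch of the renamed API follows; the includes and main() scaffolding are assumptions, and only the Context and AscendDeviceInfo calls are taken from this diff:

    // Sketch only, not part of this commit: targeting an Ascend 310/710 device
    // through the renamed AscendDeviceInfo.
    #include <memory>
    #include "include/api/context.h"

    int main() {
      auto context = std::make_shared<mindspore::Context>();
      auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();  // formerly Ascend310DeviceInfo
      ascend_info->SetDeviceID(0);
      context->MutableDeviceInfo().emplace_back(ascend_info);
      // The context is then passed to Model::Build as usual (not shown here).
      return 0;
    }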

@@ -308,7 +308,7 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
 }
 std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
-/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
+/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend. This option is
 /// invalid for MindSpore Lite.
 class MS_API AscendDeviceInfo : public DeviceInfoContext {
  public:

@@ -212,7 +212,7 @@ class MS_API Model {
   /// \brief Inference model.
   ///
-  /// \param[in] device_type Device type, options are kGPU, kAscend, kAscend910, etc.
+  /// \param[in] device_type Device type, options are kGPU, kAscend, etc.
   /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
   ///
   /// \return Is supported or not.

@@ -33,15 +33,15 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
   if (device_infos.size() != 1) {
     return;
   }
-  auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
-  if (ascend310_info == nullptr) {
+  auto ascend_info = device_infos[0]->Cast<AscendDeviceInfo>();
+  if (ascend_info == nullptr) {
     return;
   }
-  insert_op_cfg_path_ = ascend310_info->GetInsertOpConfigPath();
-  input_format_ = ascend310_info->GetInputFormat();
-  input_shape_map_ = ascend310_info->GetInputShapeMap();
-  auto out_type = ascend310_info->GetOutputType();
+  insert_op_cfg_path_ = ascend_info->GetInsertOpConfigPath();
+  input_format_ = ascend_info->GetInputFormat();
+  input_shape_map_ = ascend_info->GetInputShapeMap();
+  auto out_type = ascend_info->GetOutputType();
   auto iter = kSupportedDtypeOptionMap.find(out_type);
   if (out_type == DataType::kTypeUnknown) {
     // do nothing
@@ -50,13 +50,13 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
   } else {
     output_type_ = iter->second;
   }
-  dynamic_batch_size_ = ascend310_info->GetDynamicBatchSize();
-  dynamic_image_size_ = ascend310_info->GetDynamicImageSize();
-  precision_mode_ = ascend310_info->GetPrecisionMode();
-  op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
-  fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
-  device_id_ = ascend310_info->GetDeviceID();
-  buffer_optimize_mode_ = ascend310_info->GetBufferOptimizeMode();
+  dynamic_batch_size_ = ascend_info->GetDynamicBatchSize();
+  dynamic_image_size_ = ascend_info->GetDynamicImageSize();
+  precision_mode_ = ascend_info->GetPrecisionMode();
+  op_select_impl_mode_ = ascend_info->GetOpSelectImplMode();
+  fusion_switch_cfg_path_ = ascend_info->GetFusionSwitchConfigPath();
+  device_id_ = ascend_info->GetDeviceID();
+  buffer_optimize_mode_ = ascend_info->GetBufferOptimizeMode();
   const char *soc_name = aclrtGetSocName();
   if (soc_name == nullptr) {
     MS_LOG(WARNING) << "Get soc version failed.";

@@ -44,7 +44,7 @@ typedef struct NpuDeviceInfo {
   int frequency_ = 3; /**< npu frequency inference, low 1, medium 2, high 3, extreme 4, other values will be set to 3 */
 } NpuDeviceInfo;
-/// \brief Ascend310DeviceInfo defined for Ascend's configuration information.
+/// \brief AscendDeviceInfo defined for Ascend's configuration information.
 typedef struct AscendDeviceInfo {
   uint32_t device_id_;
   std::string batch_size_;
@@ -55,7 +55,7 @@ struct DeviceInfo {
   CpuDeviceInfo cpu_device_info_;
   GpuDeviceInfo gpu_device_info_;
   NpuDeviceInfo npu_device_info_;
-  AscendDeviceInfo ascend310_device_info_;
+  AscendDeviceInfo ascend_device_info_;
 };
 /// \brief DeviceContext defined for holding backend's configuration information.

@@ -27,10 +27,10 @@ typedef enum {
 /// \brief DeviceType defined for holding user's preferred backend.
 typedef enum {
   DT_CPU, /**< CPU device type */
   DT_GPU, /**< GPU device type */
   DT_NPU, /**< NPU device type */
-  DT_ASCEND310 /**< ASCEND310 device type */
+  DT_ASCEND /**< ASCEND device type */
 } DeviceType;
 typedef enum {

@@ -276,6 +276,7 @@ add_library(lite_src_mid OBJECT ${LITE_SRC})
 add_dependencies(lite_src_mid fbs_src)
 if(MSLITE_ENABLE_ACL)
+    include_directories(${TOP_DIR}/graphengine/inc/external)
     add_subdirectory(runtime/kernel/ascend310)
     link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
 endif()

@@ -88,22 +88,22 @@ std::vector<size_t> GetBatchSize(const std::string &batch_size) {
   return res;
 }
-std::shared_ptr<mindspore::Ascend310DeviceInfo> Ascend310DeviceInfoFromAscend310DeviceContext(
-  const lite::DeviceContext &ascend310_context) {
-  if (ascend310_context.device_type_ != DT_ASCEND310) {
-    MS_LOG(ERROR) << "Function input parameter is not ascend310 context.";
+std::shared_ptr<mindspore::AscendDeviceInfo> AscendDeviceInfoFromAscendDeviceContext(
+  const lite::DeviceContext &ascend_context) {
+  if (ascend_context.device_type_ != DT_ASCEND) {
+    MS_LOG(ERROR) << "Function input parameter is not ascend context.";
     return nullptr;
   }
-  auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
-  MS_CHECK_TRUE_RET(ascend310_info != nullptr, nullptr);
-  ascend310_info->SetDeviceID(ascend310_context.device_info_.ascend310_device_info_.device_id_);
-  std::string batch_size = ascend310_context.device_info_.ascend310_device_info_.batch_size_;
+  auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();
+  MS_CHECK_TRUE_RET(ascend_info != nullptr, nullptr);
+  ascend_info->SetDeviceID(ascend_context.device_info_.ascend_device_info_.device_id_);
+  std::string batch_size = ascend_context.device_info_.ascend_device_info_.batch_size_;
   if (!batch_size.empty()) {
     auto val = GetBatchSize(batch_size);
-    ascend310_info->SetDynamicBatchSize(val);
+    ascend_info->SetDynamicBatchSize(val);
   }
-  ascend310_info->SetDynamicImageSize(ascend310_context.device_info_.ascend310_device_info_.image_size_);
-  return ascend310_info;
+  ascend_info->SetDynamicImageSize(ascend_context.device_info_.ascend_device_info_.image_size_);
+  return ascend_info;
 }
 }  // namespace
@@ -126,7 +126,7 @@ mindspore::Context *MSContextFromContext(const lite::Context *context) {
   transfer_funcs = {{DT_CPU, CPUDeviceInfoFromCPUDeviceContext},
                     {DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
                     {DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
-                    {DT_ASCEND310, Ascend310DeviceInfoFromAscend310DeviceContext}};
+                    {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext}};
   for (auto &device_context : context->device_list_) {
     auto device_type = device_context.device_type_;
     if (transfer_funcs.find(device_type) == transfer_funcs.end()) {

@@ -37,18 +37,18 @@ constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequ
 constexpr auto kModelOptionProvider = "mindspore.option.provider";
 constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
 constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
-constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
-constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
-constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
-constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
-constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
-constexpr auto kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
-constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
-constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
-constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path";
-constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
-constexpr auto kModelOptionAscend310DynamicImageSize = "mindspore.option.ascend310.dynamic_image_size";
-constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
+constexpr auto kModelOptionAscendDeviceID = kModelOptionDeviceID;
+constexpr auto kModelOptionAscendInsertOpCfgPath = "mindspore.option.ascend.insert_op_config_file_path";
+constexpr auto kModelOptionAscendInputFormat = "mindspore.option.ascend.input_format";
+constexpr auto kModelOptionAscendInputShapeMap = "mindspore.option.ascend.input_shape_map";
+constexpr auto kModelOptionAscendInputShape = "mindspore.option.ascend.input_shape";
+constexpr auto kModelOptionAscendOutputType = "mindspore.option.ascend.output_type";
+constexpr auto kModelOptionAscendPrecisionMode = "mindspore.option.ascend.precision_mode";
+constexpr auto kModelOptionAscendOpSelectImplMode = "mindspore.option.ascend.op_select_impl_mode";
+constexpr auto KModelOptionAscendFusionSwitchCfgPath = "mindspore.option.ascend.fusion_switch_config_file_path";
+constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dynamic_batch_size";
+constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
+constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
 Context::Context() : data_(std::make_shared<Data>()) {}
@@ -372,7 +372,7 @@ void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310DeviceID] = device_id;
+  data_->params[kModelOptionAscendDeviceID] = device_id;
 }
 uint32_t AscendDeviceInfo::GetDeviceID() const {
@@ -380,7 +380,7 @@ uint32_t AscendDeviceInfo::GetDeviceID() const {
     MS_LOG(ERROR) << "Invalid context.";
     return 0;
   }
-  return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
+  return GetValue<uint32_t>(data_, kModelOptionAscendDeviceID);
 }
 void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
@@ -388,14 +388,14 @@ void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path)
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
+  data_->params[kModelOptionAscendInsertOpCfgPath] = CharToString(cfg_path);
 }
 std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInsertOpCfgPath);
   return StringToChar(ref);
 }
@@ -404,7 +404,7 @@ void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
+  data_->params[kModelOptionAscendInputFormat] = CharToString(format);
 }
 std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
@@ -412,7 +412,7 @@ std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputFormat);
   return StringToChar(ref);
 }
@@ -421,14 +421,14 @@ void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
+  data_->params[kModelOptionAscendInputShape] = CharToString(shape);
 }
 std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputShape);
   return StringToChar(ref);
 }
@@ -444,7 +444,7 @@ void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_ba
     }
     batchs += std::to_string(dynamic_batch_size[i]);
   }
-  data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
+  data_->params[kModelOptionAscendDynamicBatchSize] = batchs;
 }
 std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
@@ -452,7 +452,7 @@ std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicBatchSize);
   return StringToChar(ref);
 }
@@ -461,7 +461,7 @@ void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_imag
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size);
+  data_->params[kModelOptionAscendDynamicImageSize] = CharToString(dynamic_image_size);
 }
 std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
@@ -469,7 +469,7 @@ std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicImageSize);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicImageSize);
   return StringToChar(ref);
 }
@@ -478,7 +478,7 @@ void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode)
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
+  data_->params[kModelOptionAscendPrecisionMode] = CharToString(precision_mode);
 }
 std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
@@ -486,7 +486,7 @@ std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendPrecisionMode);
   return StringToChar(ref);
 }
@@ -495,7 +495,7 @@ void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_im
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
+  data_->params[kModelOptionAscendOpSelectImplMode] = CharToString(op_select_impl_mode);
 }
 std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
@@ -503,7 +503,7 @@ std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendOpSelectImplMode);
   return StringToChar(ref);
 }
@@ -512,14 +512,14 @@ void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_pa
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
+  data_->params[KModelOptionAscendFusionSwitchCfgPath] = CharToString(cfg_path);
 }
 std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
+  const std::string &ref = GetValue<std::string>(data_, KModelOptionAscendFusionSwitchCfgPath);
   return StringToChar(ref);
 }
@@ -528,7 +528,7 @@ void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &s
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310InputShapeMap] = shape;
+  data_->params[kModelOptionAscendInputShapeMap] = shape;
 }
 std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
@@ -536,7 +536,7 @@ std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::map<int, std::vector<int>>();
   }
-  return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
+  return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscendInputShapeMap);
 }
 void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
@@ -544,7 +544,7 @@ void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310OutputType] = output_type;
+  data_->params[kModelOptionAscendOutputType] = output_type;
 }
 enum DataType AscendDeviceInfo::GetOutputType() const {
@@ -552,7 +552,7 @@ enum DataType AscendDeviceInfo::GetOutputType() const {
     MS_LOG(ERROR) << "Invalid context.";
     return DataType::kTypeUnknown;
   }
-  return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
+  return GetValue<enum DataType>(data_, kModelOptionAscendOutputType);
 }
 void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
@@ -560,7 +560,7 @@ void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_opt
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
+  data_->params[kModelOptionAscendBufferOptimize] = CharToString(buffer_optimize_mode);
 }
 std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
@@ -568,7 +568,7 @@ std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
     MS_LOG(ERROR) << "Invalid context.";
     return std::vector<char>();
   }
-  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize);
   return StringToChar(ref);
 }
 }  // namespace mindspore
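
All of these option keys now live under "mindspore.option.ascend.*" (the device ID stays on the shared "mindspore.option.device_id" key), so the same setters configure an Ascend310 or Ascend710 target. A small sketch follows, assuming the std::string convenience overloads that wrap the char-vector setters shown above; the values and the config path are purely illustrative:

    // Sketch only, not part of this commit: the renamed option keys in use.
    #include <memory>
    #include "include/api/context.h"

    void ConfigureAscend(const std::shared_ptr<mindspore::AscendDeviceInfo> &info) {
      info->SetDeviceID(0);                     // "mindspore.option.device_id"
      info->SetPrecisionMode("force_fp16");     // "mindspore.option.ascend.precision_mode" (illustrative value)
      info->SetInputFormat("NCHW");             // "mindspore.option.ascend.input_format"
      info->SetDynamicBatchSize({1, 2, 4, 8});  // joined into "mindspore.option.ascend.dynamic_batch_size"
      info->SetFusionSwitchConfigPath("/tmp/fusion_switch.cfg");  // illustrative path
    }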

@@ -60,12 +60,12 @@ Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_conte
   return kSuccess;
 }
-Status ContextUtils::AddAscend310Device(lite::InnerContext *inner_context, DeviceInfoContext *device) {
+Status ContextUtils::AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device) {
   lite::DeviceInfo device_info = {0};
-  auto ascend310_context = device->Cast<Ascend310DeviceInfo>();
-  device_info.ascend310_device_info_ = {ascend310_context->GetDeviceID(), ascend310_context->GetDynamicBatchSize(),
-                                        ascend310_context->GetDynamicImageSize()};
-  inner_context->device_list_.push_back({lite::DT_ASCEND310, device_info});
+  auto ascend_context = device->Cast<AscendDeviceInfo>();
+  device_info.ascend_device_info_ = {ascend_context->GetDeviceID(), ascend_context->GetDynamicBatchSize(),
+                                     ascend_context->GetDynamicImageSize()};
+  inner_context->device_list_.push_back({lite::DT_ASCEND, device_info});
   return kSuccess;
 }
@@ -111,7 +111,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) {
       auto npu_context = device->Cast<KirinNPUDeviceInfo>();
       ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
     } else if (device->GetDeviceType() == kAscend) {
-      ret = AddAscend310Device(inner_context.get(), device.get());
+      ret = AddAscendDevice(inner_context.get(), device.get());
     }
     if (ret != kSuccess) {
       MS_LOG(ERROR) << "Add device failed!";

@@ -43,7 +43,7 @@ class ContextUtils {
                              const std::string &provider_device, const std::shared_ptr<Allocator> &allocator,
                              lite::InnerContext *inner_context);
   static Status AddNpuDevice(int frequency, lite::InnerContext *inner_context);
-  static Status AddAscend310Device(lite::InnerContext *inner_context, DeviceInfoContext *device);
+  static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device);
   static bool IsAffinityModeValid(int affinity_mode) {
     return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
   }

@@ -18,6 +18,9 @@
 #ifdef GPU_TENSORRT
 #include <cuda_runtime.h>
 #endif
+#ifdef ENABLE_LITE_ACL
+#include "acl/acl_base.h"
+#endif
 #include <mutex>
 #include "include/api/callback/callback.h"
 #include "include/api/context.h"
@@ -153,8 +156,11 @@ Model::Model() : impl_(nullptr) {}
 Model::~Model() {}
 bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
-  if (device_type == kGPU) {
+  if (device_type == kCPU) {
+    return true;
+  }
 #ifdef GPU_TENSORRT
+  if (device_type == kGPU) {
     int driver_version = 0;
     int ret = cudaDriverGetVersion(&driver_version);
     if (ret != cudaSuccess || driver_version == 0) {
@@ -162,18 +168,24 @@ bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type)
       return false;
     }
     return true;
-#else
-    return false;
-#endif
-  } else if (device_type == kCPU) {
-#ifdef ENABLE_LITE_ACL
-    return false;
-#else
-    return true;
-#endif
-  } else {
-    return false;
   }
+#endif
+#ifdef ENABLE_LITE_ACL
+  if (device_type == kAscend || device_type == kAscend310) {
+    const char *soc_name_c = aclrtGetSocName();
+    if (soc_name_c == nullptr) {
+      MS_LOG(WARNING) << "aclrtGetSocName failed.";
+      return false;
+    }
+    std::string soc_name(soc_name_c);
+    if (soc_name.find("910") != std::string::npos) {
+      MS_LOG(WARNING) << "Device not support, aclrtGetSocName: " << soc_name;
+      return false;
+    }
+    return true;
+  }
+#endif
+  return false;
 }
 std::vector<MSTensor> Model::GetInputs() {
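
After this change, CheckModelSupport reports CPU as always supported, keeps the GPU probe behind GPU_TENSORRT, and accepts Ascend targets only when aclrtGetSocName does not report a 910 SoC. A caller-side sketch, under the assumption that CheckModelSupport is the usual static member declared in include/api/model.h (that declaration is not part of this diff):

    // Sketch only: probing runtime support before building a model.
    #include "include/api/model.h"
    #include "include/api/types.h"

    bool CanRunOnAscend() {
      // True on Ascend 310/710 packages built with ENABLE_LITE_ACL, false when the
      // detected SoC name contains "910", mirroring the check added above.
      return mindspore::Model::CheckModelSupport(mindspore::kAscend, mindspore::ModelType::kMindIR);
    }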

@@ -279,7 +279,7 @@ bool InnerContext::IsProviderEnabled() const {
 bool InnerContext::IsAllDeviceTypeValid() const {
   return std::all_of(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) {
-    return device.device_type_ >= DT_CPU && device.device_type_ <= DT_ASCEND310;
+    return device.device_type_ >= DT_CPU && device.device_type_ <= DT_ASCEND;
   });
 }

@@ -34,17 +34,17 @@ constexpr auto kNCHWHeightIdx = 2;
 constexpr auto kNCHWWidthIdx = 3;
 constexpr auto kImageSizeHwNum = 2;
 }  // namespace
-CustomAscend310Kernel::CustomAscend310Kernel(const std::vector<mindspore::MSTensor> &inputs,
+CustomAscendKernel::CustomAscendKernel(const std::vector<mindspore::MSTensor> &inputs,
                                        const std::vector<mindspore::MSTensor> &outputs,
                                        const schema::Primitive *primitive, const mindspore::Context *ctx)
     : Kernel(inputs, outputs, primitive, ctx),
       load_model_(false),
       acl_options_({}),
       model_infer_(nullptr),
       InputDataIndex_(0) {}
-CustomAscend310Kernel::~CustomAscend310Kernel() {
-  if (model_infer_ != nullptr) {
+CustomAscendKernel::~CustomAscendKernel() {
+  if (load_model_) {
     int ret = model_infer_->Finalize();
     if (ret != lite::RET_OK) {
       MS_LOG(ERROR) << "Model finalize failed.";
@@ -52,7 +52,7 @@ CustomAscend310Kernel::~CustomAscend310Kernel() {
   }
 }
-STATUS CustomAscend310Kernel::PrepareModelInfer() {
+STATUS CustomAscendKernel::PrepareModelInfer() {
   if (inputs_.size() < 1) {
     MS_LOG(ERROR) << "Inputs size should not be less than 1.";
     return lite::RET_ERROR;
@@ -85,7 +85,7 @@ STATUS CustomAscend310Kernel::PrepareModelInfer() {
   return lite::RET_OK;
 }
-STATUS CustomAscend310Kernel::Prepare() {
+STATUS CustomAscendKernel::Prepare() {
   if (load_model_) {
     MS_LOG(INFO) << "Custom kernel has been prepared.";
     return lite::RET_OK;
@@ -100,7 +100,7 @@ STATUS CustomAscend310Kernel::Prepare() {
   return lite::RET_OK;
 }
-void CustomAscend310Kernel::RecordInputDataIndex() {
+void CustomAscendKernel::RecordInputDataIndex() {
   for (size_t idx = 0; idx < inputs_.size(); ++idx) {
     if (inputs_[idx].Data() == nullptr) {
       InputDataIndex_ = idx;
@@ -109,14 +109,14 @@ void CustomAscend310Kernel::RecordInputDataIndex() {
     }
   }
-STATUS CustomAscend310Kernel::ReSize() {
+STATUS CustomAscendKernel::ReSize() {
   if (!load_model_) {
     return Prepare();
   }
   return lite::RET_OK;
 }
-STATUS CustomAscend310Kernel::ProcDynamicInput(std::vector<mindspore::MSTensor> *inputs) {
+STATUS CustomAscendKernel::ProcDynamicInput(std::vector<mindspore::MSTensor> *inputs) {
   if (acl_options_.batch_size.empty() && acl_options_.image_size.empty()) {
     MS_LOG(INFO) << "Input is not dynamic mode.";
     return lite::RET_OK;
@@ -154,7 +154,7 @@ STATUS CustomAscend310Kernel::ProcDynamicInput(std::vector<mindspore::MSTensor>
   return lite::RET_OK;
 }
-STATUS CustomAscend310Kernel::GetRealBatchSize(std::vector<mindspore::MSTensor> *inputs, int32_t *batch_size) {
+STATUS CustomAscendKernel::GetRealBatchSize(std::vector<mindspore::MSTensor> *inputs, int32_t *batch_size) {
   CHECK_NULL_RETURN(batch_size);
   if (InputDataIndex_ >= inputs->size()) {
     MS_LOG(ERROR) << " Input data index " << InputDataIndex_ << " is larger than input size " << inputs->size();
@@ -177,8 +177,8 @@ STATUS CustomAscend310Kernel::GetRealBatchSize(std::vector<mindspore::MSTensor>
   return lite::RET_OK;
 }
-STATUS CustomAscend310Kernel::GetRealImageSize(std::vector<mindspore::MSTensor> *inputs, int32_t *image_size,
+STATUS CustomAscendKernel::GetRealImageSize(std::vector<mindspore::MSTensor> *inputs, int32_t *image_size,
                                             int32_t num) {
   CHECK_NULL_RETURN(image_size);
   if (InputDataIndex_ >= inputs->size()) {
     MS_LOG(ERROR) << "Input data index " << InputDataIndex_ << " is larger than input size " << inputs->size();
@@ -217,7 +217,7 @@ STATUS CustomAscend310Kernel::GetRealImageSize(std::vector<mindspore::MSTensor>
   return lite::RET_OK;
 }
-STATUS CustomAscend310Kernel::Execute() {
+STATUS CustomAscendKernel::Execute() {
   if (!load_model_) {
     MS_LOG(WARNING) << "Custom kernel has not been prepared.";
     return lite::RET_OK;
@@ -246,7 +246,7 @@ std::shared_ptr<kernel::Kernel> CustomCreateKernel(const std::vector<mindspore::
     return nullptr;
   }
-  auto kernel = std::make_shared<CustomAscend310Kernel>(inputs, outputs, primitive, ctx);
+  auto kernel = std::make_shared<CustomAscendKernel>(inputs, outputs, primitive, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "New custom kernel is nullptr";
     return nullptr;
@@ -262,8 +262,8 @@ const auto kFloat32 = DataType::kNumberTypeFloat32;
 const auto kInt8 = DataType::kNumberTypeInt8;
 const auto kUInt8 = DataType::kNumberTypeUInt8;
 }  // namespace
-REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kFloat32, ACL, kernel::acl::CustomCreateKernel)
-REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kInt8, ACL, kernel::acl::CustomCreateKernel)
-REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kUInt8, ACL, kernel::acl::CustomCreateKernel)
+REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kFloat32, ACL, kernel::acl::CustomCreateKernel)
+REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kInt8, ACL, kernel::acl::CustomCreateKernel)
+REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kUInt8, ACL, kernel::acl::CustomCreateKernel)
 }  // namespace registry
 }  // namespace mindspore

@@ -31,11 +31,11 @@ namespace mindspore::kernel {
 namespace acl {
 using mindspore::lite::STATUS;
-class CustomAscend310Kernel : public kernel::Kernel {
+class CustomAscendKernel : public kernel::Kernel {
  public:
-  CustomAscend310Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
+  CustomAscendKernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
                      const mindspore::schema::Primitive *primitive, const mindspore::Context *ctx);
-  ~CustomAscend310Kernel() override;
+  ~CustomAscendKernel() override;
   STATUS Prepare() override;
   STATUS ReSize() override;

@@ -295,7 +295,7 @@ int BenchmarkBase::CheckThreadNumValid() {
 int BenchmarkBase::CheckDeviceTypeValid() {
   if (flags_->device_ != "CPU" && flags_->device_ != "GPU" && flags_->device_ != "NPU" &&
-      flags_->device_ != "Ascend310") {
+      flags_->device_ != "Ascend310" && flags_->device_ != "Ascend710") {
     MS_LOG(ERROR) << "Device type:" << flags_->device_ << " is not supported.";
     std::cerr << "Device type:" << flags_->device_ << " is not supported." << std::endl;
     return RET_ERROR;

@@ -119,7 +119,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::model_type_, "modelType", "Input model type. MindIR | MindIR_Opt", "MindIR");
     AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
     AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", "");
-    AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310", "CPU");
+    AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend710", "CPU");
     AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1);
     // MarkPerformance
     AddFlag(&BenchmarkFlags::loop_count_, "loopCount", "Run loop count", 10);

@@ -23,7 +23,6 @@
 #include <functional>
 #include <iomanip>
 #include <limits>
-#include "include/context.h"
 #include "include/ms_tensor.h"
 #include "include/version.h"
 #include "schema/model_generated.h"
@@ -375,10 +374,10 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context>
     device_list.push_back(npu_device_info);
   }
-  if (flags_->device_ == "Ascend310") {
-    std::shared_ptr<Ascend310DeviceInfo> ascend310_device_info = std::make_shared<Ascend310DeviceInfo>();
-    ascend310_device_info->SetDeviceID(0);
-    device_list.push_back(ascend310_device_info);
+  if (flags_->device_ == "Ascend310" || flags_->device_ == "Ascend710") {
+    std::shared_ptr<AscendDeviceInfo> ascend_device_info = std::make_shared<AscendDeviceInfo>();
+    ascend_device_info->SetDeviceID(0);
+    device_list.push_back(ascend_device_info);
   }
   // CPU priority is behind GPU and NPU

@@ -343,42 +343,42 @@ STATUS AclPassImpl::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_
   return lite::RET_OK;
 }
-void AclPassImpl::SetAclModelInitOptions(const std::shared_ptr<Ascend310DeviceInfo> &ascend310_info) {
+void AclPassImpl::SetAclModelInitOptions(const std::shared_ptr<AscendDeviceInfo> &ascend_info) {
   if (!acl_model_option_cfg_.fusion_switch_config_file_path.empty()) {
-    ascend310_info->SetFusionSwitchConfigPath(acl_model_option_cfg_.fusion_switch_config_file_path);
+    ascend_info->SetFusionSwitchConfigPath(acl_model_option_cfg_.fusion_switch_config_file_path);
   }
   if (!acl_model_option_cfg_.op_select_impl_mode.empty()) {
-    ascend310_info->SetOpSelectImplMode(acl_model_option_cfg_.op_select_impl_mode);
+    ascend_info->SetOpSelectImplMode(acl_model_option_cfg_.op_select_impl_mode);
   }
   if (!acl_model_option_cfg_.buffer_optimize.empty()) {
-    ascend310_info->SetBufferOptimizeMode(acl_model_option_cfg_.buffer_optimize);
+    ascend_info->SetBufferOptimizeMode(acl_model_option_cfg_.buffer_optimize);
   }
 }
-void AclPassImpl::SetAclModelBuildOptions(const std::shared_ptr<Ascend310DeviceInfo> &ascend310_info) {
+void AclPassImpl::SetAclModelBuildOptions(const std::shared_ptr<AscendDeviceInfo> &ascend_info) {
   if (acl_model_option_cfg_.output_type != DataType::kInvalidType) {
-    ascend310_info->SetOutputType(acl_model_option_cfg_.output_type);
+    ascend_info->SetOutputType(acl_model_option_cfg_.output_type);
   }
   if (acl_model_option_cfg_.input_shape_map.size() > 0) {
-    ascend310_info->SetInputShapeMap(acl_model_option_cfg_.input_shape_map);
+    ascend_info->SetInputShapeMap(acl_model_option_cfg_.input_shape_map);
   }
   if (acl_model_option_cfg_.dynamic_batch_size.size() > 0) {
-    ascend310_info->SetDynamicBatchSize(acl_model_option_cfg_.dynamic_batch_size);
+    ascend_info->SetDynamicBatchSize(acl_model_option_cfg_.dynamic_batch_size);
   }
   if (!acl_model_option_cfg_.dynamic_image_size.empty()) {
-    ascend310_info->SetDynamicImageSize(acl_model_option_cfg_.dynamic_image_size);
+    ascend_info->SetDynamicImageSize(acl_model_option_cfg_.dynamic_image_size);
   }
   if (!acl_model_option_cfg_.input_format.empty()) {
-    ascend310_info->SetInputFormat(acl_model_option_cfg_.input_format);
+    ascend_info->SetInputFormat(acl_model_option_cfg_.input_format);
   }
   if (!acl_model_option_cfg_.input_shape.empty()) {
-    ascend310_info->SetInputShape(acl_model_option_cfg_.input_shape);
+    ascend_info->SetInputShape(acl_model_option_cfg_.input_shape);
   }
   if (!acl_model_option_cfg_.precision_mode.empty()) {
-    ascend310_info->SetPrecisionMode(acl_model_option_cfg_.precision_mode);
+    ascend_info->SetPrecisionMode(acl_model_option_cfg_.precision_mode);
   }
   if (!acl_model_option_cfg_.insert_op_config_file_path.empty()) {
-    ascend310_info->SetInsertOpConfigPath(acl_model_option_cfg_.insert_op_config_file_path);
+    ascend_info->SetInsertOpConfigPath(acl_model_option_cfg_.insert_op_config_file_path);
   }
 }
@@ -387,15 +387,15 @@ std::shared_ptr<mindspore::Context> AclPassImpl::CreateModelContext() {
   if (model_context == nullptr) {
     return nullptr;
   }
-  auto ascend310_info = std::make_shared<Ascend310DeviceInfo>();
-  if (ascend310_info == nullptr) {
+  auto ascend_info = std::make_shared<AscendDeviceInfo>();
+  if (ascend_info == nullptr) {
     return nullptr;
   }
-  ascend310_info->SetDeviceID(acl_model_option_cfg_.device_id);
-  SetAclModelInitOptions(ascend310_info);
-  SetAclModelBuildOptions(ascend310_info);
-  model_context->MutableDeviceInfo().emplace_back(ascend310_info);
+  ascend_info->SetDeviceID(acl_model_option_cfg_.device_id);
+  SetAclModelInitOptions(ascend_info);
+  SetAclModelBuildOptions(ascend_info);
+  model_context->MutableDeviceInfo().emplace_back(ascend_info);
   return model_context;
 }

@@ -62,8 +62,8 @@ class AclPassImpl {
   STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph);
   STATUS TraceOutput(const AnfNodePtr &node);
   std::shared_ptr<mindspore::Context> CreateModelContext();
-  void SetAclModelInitOptions(const std::shared_ptr<Ascend310DeviceInfo> &ascend310_info);
-  void SetAclModelBuildOptions(const std::shared_ptr<Ascend310DeviceInfo> &ascend310_info);
+  void SetAclModelInitOptions(const std::shared_ptr<AscendDeviceInfo> &ascend_info);
+  void SetAclModelBuildOptions(const std::shared_ptr<AscendDeviceInfo> &ascend_info);
   std::string AdjustCnodeName(const PrimitivePtr &prim);
   bool IsDynamicInput();

@@ -82,7 +82,7 @@ Flags::Flags() {
           "");
   AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
           "Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC");
-  AddFlag(&Flags::device, "device", "Set the target device. Only valid when device is Ascend310.", "");
+  AddFlag(&Flags::device, "device", "Set the target device. Only valid when device is Ascend310 or Ascend710.", "");
 }
 int Flags::InitInputOutputDataType() {