forked from mindspore-Ecosystem/mindspore
!48214 [MSLITE] Complete the cloud-side inference APIs
Merge pull request !48214 from zhangdanyang/master
This commit is contained in:
commit b6a7c6df63
@@ -5,6 +5,7 @@
 #
 mindspore/mindspore/core/mindrt/src/thread/actor_threadpool.cc:mindspore::ActorWorker::RunWithSpin
 mindspore/mindspore/lite/src/common/ops/primitive_c.cc:mindspore::lite::PrimitiveC::Create
+mindspore/mindspore/lite/src/extendrt/convert/runtime_convert.cc:RuntimeConvert
 mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mindspore::dataset::CsvOp::CsvParser::InitCsvParser
 mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
 mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
@@ -421,7 +421,8 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
 Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
                     const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path) {
-  return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode),
+  auto model_path_char = StringToChar(model_path);
+  return Build(model_path_char, model_type, model_context, dec_key, StringToChar(dec_mode),
                StringToChar(cropto_lib_path));
 }
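The only change here is that the converted path now lives in a named local before being forwarded to the std::vector<char> overload. A minimal sketch of the dual-ABI helpers involved (assumed to mirror include/api/dual_abi_helper.h; the bodies below are illustrative, not the shipped ones):

#include <string>
#include <vector>

// std::string crosses the ABI boundary as std::vector<char>.
std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }

int main() {
  // Naming the converted buffer makes its lifetime explicit before it is
  // forwarded to the vector<char> overload of Build().
  auto model_path_char = StringToChar("model.mindir");
  return CharToString(model_path_char) == "model.mindir" ? 0 : 1;
}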
@@ -25,6 +25,21 @@ static const std::map<enum DataType, std::string> kSupportedDtypeOptionMap = {{D
   {DataType::kNumberTypeFloat32, "FP32"},
   {DataType::kNumberTypeUInt8, "UINT8"}};
+
+std::string TransforPrecisionToAcl(std::string getPrecisionMode) {
+  if (getPrecisionMode == "enforce_fp32") {
+    return "force_fp32";
+  } else if (getPrecisionMode == "preferred_fp32") {
+    return "allow_fp32_to_fp16";
+  } else if (getPrecisionMode == "enforce_fp16") {
+    return "force_fp16";
+  } else if (getPrecisionMode == "enforce_origin") {
+    return "must_keep_origin_dtype";
+  } else if (getPrecisionMode == "preferred_optimal") {
+    return "allow_mix_precision";
+  }
+  return getPrecisionMode;
+}
+
 AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
   if (context == nullptr) {
     return;
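For reference, the helper above maps the user-facing MindSpore precision names onto CANN/ACL option values and passes unknown names through unchanged. A self-contained check (the body is copied from the diff, condensed into early returns):

#include <cassert>
#include <string>

std::string TransforPrecisionToAcl(std::string mode) {
  if (mode == "enforce_fp32") return "force_fp32";
  if (mode == "preferred_fp32") return "allow_fp32_to_fp16";
  if (mode == "enforce_fp16") return "force_fp16";
  if (mode == "enforce_origin") return "must_keep_origin_dtype";
  if (mode == "preferred_optimal") return "allow_mix_precision";
  return mode;  // unknown names pass through unchanged
}

int main() {
  assert(TransforPrecisionToAcl("preferred_optimal") == "allow_mix_precision");
  assert(TransforPrecisionToAcl("force_fp16") == "force_fp16");  // already an ACL name
  return 0;
}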
@@ -52,7 +67,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
   }
   dynamic_batch_size_ = ascend_info->GetDynamicBatchSize();
   dynamic_image_size_ = ascend_info->GetDynamicImageSize();
-  precision_mode_ = ascend_info->GetPrecisionMode();
+  precision_mode_ = TransforPrecisionToAcl(ascend_info->GetPrecisionMode());
   op_select_impl_mode_ = ascend_info->GetOpSelectImplMode();
   fusion_switch_cfg_path_ = ascend_info->GetFusionSwitchConfigPath();
   device_id_ = ascend_info->GetDeviceID();
@@ -48,29 +48,50 @@ int RuntimeConvert(const mindspore::api::FuncGraphPtr &graph, const std::shared_
     std::map<std::string, std::string> ascend_map = config_info.at("ascend_context");
     mindspore::lite::ConfigFileParser config_parser;
     config_parser.SetParamByConfigfile(param, ascend_map);
   } else {
-    auto ascend_info = device->Cast<mindspore::AscendDeviceInfo>();
-    std::string dynamic_batch_size = ascend_info->GetDynamicBatchSize();
-    if (!dynamic_batch_size.empty()) {
-      std::vector<std::string> batch_size_string = mindspore::lite::SplitStringToVector(dynamic_batch_size, ',');
-      for (const auto &item : batch_size_string) {
-        int32_t val;
-        if (mindspore::lite::ConvertIntNum(item, &val)) {
-          size_t tmp_val = static_cast<size_t>(val);
-          param->aclModelOptionCfgParam.dynamic_batch_size.push_back(tmp_val);
-        }
-      }
+    auto ascend_info = device->Cast<mindspore::AscendDeviceInfo>();
+    std::string dynamic_batch_size = ascend_info->GetDynamicBatchSize();
+    if (!dynamic_batch_size.empty()) {
+      std::vector<std::string> batch_size_string = mindspore::lite::SplitStringToVector(dynamic_batch_size, ',');
+      for (const auto &item : batch_size_string) {
+        int32_t val;
+        if (mindspore::lite::ConvertIntNum(item, &val)) {
+          size_t tmp_val = static_cast<size_t>(val);
+          param->aclModelOptionCfgParam.dynamic_batch_size.push_back(tmp_val);
+        }
+      }
+    }
+    if (ascend_info->GetDeviceID() >= 0) {
+      param->aclModelOptionCfgParam.device_id = ascend_info->GetDeviceID();
+    }
+    if (ascend_info->GetOutputType() != mindspore::DataType::kTypeUnknown) {
+      param->aclModelOptionCfgParam.output_type = ascend_info->GetOutputType();
+    }
+    if (!ascend_info->GetInputShapeMap().empty()) {
+      param->aclModelOptionCfgParam.input_shape_map = ascend_info->GetInputShapeMap();
+    }
+    if (!ascend_info->GetInputFormat().empty()) {
+      param->aclModelOptionCfgParam.input_format = ascend_info->GetInputFormat();
+    }
+    if (!ascend_info->GetInputShape().empty()) {
+      param->aclModelOptionCfgParam.input_shape = ascend_info->GetInputShape();
+    }
+    if (!ascend_info->GetPrecisionMode().empty()) {
+      param->aclModelOptionCfgParam.precision_mode = ascend_info->GetPrecisionMode();
+    }
+    if (!ascend_info->GetOpSelectImplMode().empty()) {
+      param->aclModelOptionCfgParam.op_select_impl_mode = ascend_info->GetOpSelectImplMode();
+    }
+    if (!ascend_info->GetFusionSwitchConfigPath().empty()) {
+      param->aclModelOptionCfgParam.fusion_switch_config_file_path = ascend_info->GetFusionSwitchConfigPath();
+    }
+    if (!ascend_info->GetBufferOptimizeMode().empty()) {
+      param->aclModelOptionCfgParam.buffer_optimize = ascend_info->GetBufferOptimizeMode();
+    }
+    if (!ascend_info->GetInsertOpConfigPath().empty()) {
+      param->aclModelOptionCfgParam.insert_op_config_file_path = ascend_info->GetInsertOpConfigPath();
+    }
+    if (!ascend_info->GetDynamicImageSize().empty()) {
+      param->aclModelOptionCfgParam.dynamic_image_size = ascend_info->GetDynamicImageSize();
+    }
   } else {
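With this hunk, every populated AscendDeviceInfo field is forwarded into aclModelOptionCfgParam during runtime conversion, not just the dynamic batch sizes. A hedged usage sketch of the context setup that feeds this path (API names taken from the diff; the surrounding values are illustrative):

#include <memory>
#include "include/api/context.h"

std::shared_ptr<mindspore::Context> MakeAscendContext() {
  auto context = std::make_shared<mindspore::Context>();
  auto ascend = std::make_shared<mindspore::AscendDeviceInfo>();
  ascend->SetDeviceID(0);
  ascend->SetPrecisionMode("preferred_optimal");  // becomes allow_mix_precision in ACL terms
  ascend->SetDynamicBatchSize({1, 2, 4, 8});      // the config-file path parses "1,2,4,8" instead
  context->MutableDeviceInfo().push_back(ascend);
  return context;
}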
@@ -28,12 +28,10 @@ namespace mindspore {
 constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
 constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
 constexpr auto kModelOptionNPUEnableFP16 = "mindspore.option.npu.enable_fp16";
-constexpr auto kModelOptionGPUEnableGLTexture = "mindspore.option.gpu.enable_gl_texture_";
-constexpr auto kModelOptionGPUGLContext = "mindspore.option.gpu.gl_context_";
-constexpr auto kModelOptionGPUGLDisplay = "mindspore.option.gpu.gl_display_";
 constexpr auto kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
 constexpr auto kModelOptionGPURankID = "mindspore.option.gpu.rank_id";
 constexpr auto kModelOptionGPUGroupSize = "mindspore.option.gpu.group_size";
+constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
 constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
 constexpr auto kModelOptionProvider = "mindspore.option.provider";
 constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
@@ -50,19 +48,6 @@ constexpr auto KModelOptionAscendFusionSwitchCfgPath = "mindspore.option.ascend.
 constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dynamic_batch_size";
 constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
 constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
-constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
-constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
-constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
-constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
-constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
-constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
-constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
-constexpr auto kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
-constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
-constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
-constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path";
-constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
-constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";

 Context::Context() : data_(std::make_shared<Data>()) {}
@@ -93,6 +78,14 @@ void Context::SetThreadNum(int32_t thread_num) {
   data_->thread_num = thread_num;
 }

+int32_t Context::GetThreadNum() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return 0;
+  }
+  return data_->thread_num;
+}
+
 void Context::SetInterOpParallelNum(int32_t parallel_num) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
@@ -109,50 +102,18 @@ int32_t Context::GetInterOpParallelNum() const {
   return data_->inter_op_parallel_num_;
 }

-int32_t Context::GetThreadNum() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return 0;
-  }
-  return data_->thread_num;
-}
-
-void Context::SetEnableParallel(bool is_parallel) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->enable_parallel_ = is_parallel;
-}
+void Context::SetEnableParallel(bool is_parallel) { MS_LOG(ERROR) << "Unsupported Feature."; }

 bool Context::GetEnableParallel() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return false;
-  }
-  return data_->enable_parallel_;
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return false;
 }

-void Context::SetThreadAffinity(int mode) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  if (mode < lite::NO_BIND || mode > lite::MID_CPU) {
-    MS_LOG(WARNING) << "Invalid thread affinity mode: " << mode << ", change to NO_BIND mode.";
-    data_->affinity_mode_ = lite::NO_BIND;
-    return;
-  }
-  data_->affinity_mode_ = mode;
-  return;
-}
+void Context::SetThreadAffinity(int mode) { MS_LOG(ERROR) << "Unsupported Feature."; }

 int Context::GetThreadAffinityMode() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return -1;
-  }
-  return data_->affinity_mode_;
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return 0;
 }

 void Context::SetThreadAffinity(const std::vector<int> &core_list) {
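Everything from SetEnableParallel through GetThreadAffinityMode is reduced to a logging stub in the cloud-side build. A short caller-side sketch of the resulting behavior (defaults taken directly from the stub bodies above):

#include <iostream>
#include "include/api/context.h"

int main() {
  mindspore::Context ctx;
  ctx.SetEnableParallel(true);                       // logs "Unsupported Feature."
  std::cout << ctx.GetEnableParallel() << "\n";      // always 0 (false)
  ctx.SetThreadAffinity(1);                          // logs "Unsupported Feature."
  std::cout << ctx.GetThreadAffinityMode() << "\n";  // always 0
  return 0;
}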
@@ -173,42 +134,39 @@ std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
   return data_->affinity_core_list_;
 }

+void Context::SetBuiltInDelegate(DelegateMode mode) { MS_LOG(ERROR) << "Unsupported Feature."; }
+
+DelegateMode Context::GetBuiltInDelegate() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kNoDelegate;
+}
+
 void Context::set_delegate(const std::shared_ptr<AbstractDelegate> &delegate) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->delegate = std::dynamic_pointer_cast<GraphSinkDelegate>(delegate);
+  MS_LOG(ERROR) << "Unsupported Feature.";
 }

 std::shared_ptr<AbstractDelegate> Context::get_delegate() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return nullptr;
-  }
-  return data_->delegate;
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return nullptr;
 }

 // deprecated
-void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) { return; }
+void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return;
+}

 // deprecated
-std::shared_ptr<Delegate> Context::GetDelegate() const { return nullptr; }
-
-void Context::SetMultiModalHW(bool float_mode) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->float_mode = float_mode;
+std::shared_ptr<Delegate> Context::GetDelegate() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return nullptr;
 }

+void Context::SetMultiModalHW(bool float_mode) { MS_LOG(ERROR) << "Unsupported Feature."; }
+
 bool Context::GetMultiModalHW() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return false;
-  }
-  return data_->float_mode;
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return false;
 }

 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() const {
@@ -293,6 +251,11 @@ void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
+  if (is_fp16) {
+    data_->params[kModelOptionGPUPrecisionMode] = std::string("preferred_fp16");
+  } else {
+    data_->params[kModelOptionGPUPrecisionMode] = std::string("enforce_fp32");
+  }
   data_->params[kModelOptionGPUEnableFP16] = is_fp16;
 }
@@ -304,77 +267,50 @@ bool GPUDeviceInfo::GetEnableFP16() const {
   return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
 }

-void GPUDeviceInfo::SetEnableGLTexture(bool is_enable_gl_texture) {
+void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
-  data_->params[kModelOptionGPUEnableGLTexture] = is_enable_gl_texture;
+  if (precision_mode == StringToChar("enforce_fp32")) {
+    data_->params[kModelOptionGPUEnableFP16] = false;
+  } else if (precision_mode == StringToChar("preferred_fp16")) {
+    data_->params[kModelOptionGPUEnableFP16] = true;
+  } else {
+    MS_LOG(ERROR) << "GPU only support mode enforce_fp32 and preferred_fp16.";
+    return;
+  }
+  data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode);
 }

+std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return std::vector<char>();
+  }
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionGPUPrecisionMode);
+  return StringToChar(ref);
+}
+
+void GPUDeviceInfo::SetEnableGLTexture(bool is_enable_gl_texture) { MS_LOG(ERROR) << "Unsupported Feature."; }
+
 bool GPUDeviceInfo::GetEnableGLTexture() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return false;
-  }
-  return GetValue<bool>(data_, kModelOptionGPUEnableGLTexture);
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return false;
 }

-void GPUDeviceInfo::SetGLContext(void *gl_context) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->params[kModelOptionGPUGLContext] = gl_context;
-}
+void GPUDeviceInfo::SetGLContext(void *gl_context) { MS_LOG(ERROR) << "Unsupported Feature."; }

 void *GPUDeviceInfo::GetGLContext() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return nullptr;
-  }
-  return GetValue<void *>(data_, kModelOptionGPUGLContext);
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return nullptr;
 }

-void GPUDeviceInfo::SetGLDisplay(void *gl_display) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->params[kModelOptionGPUGLDisplay] = gl_display;
-}
+void GPUDeviceInfo::SetGLDisplay(void *gl_display) { MS_LOG(ERROR) << "Unsupported Feature."; }

 void *GPUDeviceInfo::GetGLDisplay() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return nullptr;
-  }
-  return GetValue<void *>(data_, kModelOptionGPUGLDisplay);
-}
-
-void KirinNPUDeviceInfo::SetEnableFP16(bool is_fp16) {
-  MS_EXCEPTION_IF_NULL(data_);
-  data_->params[kModelOptionNPUEnableFP16] = is_fp16;
-}
-bool KirinNPUDeviceInfo::GetEnableFP16() const {
-  MS_EXCEPTION_IF_NULL(data_);
-  return GetValue<bool>(data_, kModelOptionNPUEnableFP16);
-}
-
-void KirinNPUDeviceInfo::SetFrequency(int frequency) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->params[kModelOptionKirinNpuFrequency] = frequency;
-}
-
-int KirinNPUDeviceInfo::GetFrequency() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return 0;
-  }
-  return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return nullptr;
 }

 void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
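SetPrecisionMode is the new way to choose GPU precision, and it is deliberately narrow: only enforce_fp32 and preferred_fp16 are accepted, and the choice is mirrored into the EnableFP16 flag. A hedged usage sketch (the std::string overload is assumed to forward to the std::vector<char> implementation shown above):

#include <memory>
#include "include/api/context.h"

int main() {
  auto gpu = std::make_shared<mindspore::GPUDeviceInfo>();
  gpu->SetPrecisionMode("preferred_fp16");  // accepted; also sets EnableFP16 = true
  bool fp16 = gpu->GetEnableFP16();         // true
  gpu->SetPrecisionMode("enforce_origin");  // rejected: logs an error, state unchanged
  return fp16 ? 0 : 1;
}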
@@ -403,16 +339,6 @@ int GPUDeviceInfo::GetGroupSize() const {
   return GetValue<int>(data_, kModelOptionGPUGroupSize);
 }

-void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
-  MS_LOG(ERROR) << "Unsupported Feature.";
-}
-
-std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
-  MS_LOG(ERROR) << "Unsupported Feature.";
-  std::vector<char> ret;
-  return ret;
-}
-
 void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
@@ -436,6 +362,7 @@ void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path)
   }
   data_->params[kModelOptionAscendInsertOpCfgPath] = CharToString(cfg_path);
 }

 std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
@@ -469,6 +396,7 @@ void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
   }
   data_->params[kModelOptionAscendInputShape] = CharToString(shape);
 }

 std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
@@ -478,6 +406,22 @@ std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
   return StringToChar(ref);
 }

+void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return;
+  }
+  data_->params[kModelOptionAscendInputShapeMap] = shape;
+}
+
+std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
+    return std::map<int, std::vector<int>>();
+  }
+  return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscendInputShapeMap);
+}
+
 void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
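SetInputShapeMap/GetInputShapeMap move up next to the other input-shape accessors with no behavioral change (the old copies are deleted in a later hunk). A small usage sketch (shape values are illustrative):

#include <map>
#include <memory>
#include <vector>
#include "include/api/context.h"

int main() {
  auto ascend = std::make_shared<mindspore::AscendDeviceInfo>();
  // Input index 0 fixed to NCHW 1x3x224x224 (illustrative values).
  std::map<int, std::vector<int>> shape_map = {{0, {1, 3, 224, 224}}};
  ascend->SetInputShapeMap(shape_map);
  return ascend->GetInputShapeMap().size() == 1 ? 0 : 1;
}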
@@ -524,6 +468,12 @@ void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode)
     MS_LOG(ERROR) << "Invalid context.";
     return;
   }
+  if (precision_mode != StringToChar("enforce_fp32") && precision_mode != StringToChar("preferred_fp32") &&
+      precision_mode != StringToChar("enforce_fp16") && precision_mode != StringToChar("enforce_origin") &&
+      precision_mode != StringToChar("preferred_optimal")) {
+    MS_LOG(ERROR) << "Ascend can not support " << CharToString(precision_mode) << " mode.";
+    return;
+  }
   data_->params[kModelOptionAscendPrecisionMode] = CharToString(precision_mode);
 }
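The Ascend setter now validates against the five user-facing names; anything else, including the ACL-internal spellings produced by TransforPrecisionToAcl, is rejected with an error log. A brief sketch:

#include <memory>
#include "include/api/context.h"

int main() {
  auto ascend = std::make_shared<mindspore::AscendDeviceInfo>();
  for (const char *mode : {"enforce_fp32", "preferred_fp32", "enforce_fp16",
                           "enforce_origin", "preferred_optimal"}) {
    ascend->SetPrecisionMode(mode);  // all five are accepted
  }
  ascend->SetPrecisionMode("force_fp16");  // ACL-internal name: logged and ignored
  return 0;
}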
@@ -560,6 +510,7 @@ void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_pa
   }
   data_->params[KModelOptionAscendFusionSwitchCfgPath] = CharToString(cfg_path);
 }

 std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
@@ -569,22 +520,6 @@ std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
   return StringToChar(ref);
 }

-void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return;
-  }
-  data_->params[kModelOptionAscendInputShapeMap] = shape;
-}
-
-std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
-  if (data_ == nullptr) {
-    MS_LOG(ERROR) << "Invalid context.";
-    return std::map<int, std::vector<int>>();
-  }
-  return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscendInputShapeMap);
-}
-
 void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
   if (data_ == nullptr) {
     MS_LOG(ERROR) << "Invalid context.";
@@ -16,6 +16,10 @@
 #include "include/api/model.h"
 #include "include/api/context.h"
 #include "extendrt/cxx_api/model/model_impl.h"
+#ifdef ENABLE_OPENSSL
+#include "src/common/decrypt.h"
+#include "src/common/file_utils.h"
+#endif

 namespace mindspore {
 namespace {
@@ -43,6 +47,24 @@ Model::Model() : impl_(nullptr) {

 Model::~Model() {}

+#ifdef ENABLE_OPENSSL
+Status DecryptModel(const std::string &cropto_lib_path, const void *model_buf, size_t model_size, const Key &dec_key,
+                    const std::string &dec_mode, std::unique_ptr<Byte[]> *decrypt_buffer, size_t *decrypt_len) {
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "model_buf is nullptr.";
+    return kLiteError;
+  }
+  *decrypt_len = 0;
+  *decrypt_buffer = lite::Decrypt(cropto_lib_path, decrypt_len, reinterpret_cast<const Byte *>(model_buf), model_size,
+                                  dec_key.key, dec_key.len, dec_mode);
+  if (*decrypt_buffer == nullptr || *decrypt_len == 0) {
+    MS_LOG(ERROR) << "Decrypt buffer failed";
+    return kLiteError;
+  }
+  return kSuccess;
+}
+#endif
+
 Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
                     const std::shared_ptr<Context> &model_context) {
   std::unique_lock<std::mutex> build_lock(g_build_mutex);
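With DecryptModel in place, the encrypted-model Build overloads below become usable. A hedged end-to-end sketch; the key bytes, file name, cipher mode, and OpenSSL wrapper path are placeholders, not values from the commit:

#include <memory>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/types.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  char key_bytes[16] = {/* 16-byte AES-128 key (placeholder) */};
  mindspore::Key dec_key(key_bytes, sizeof(key_bytes));

  mindspore::Model model;
  // dec_key.len > 0 selects the decrypt branch; an empty Key falls through to a plain Build().
  auto status = model.Build("encrypted.mindir", mindspore::kMindIR, context, dec_key,
                            "AES-GCM", "/usr/lib/libcrypto.so.1.1");
  return status == mindspore::kSuccess ? 0 : 1;
}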
@@ -90,11 +112,116 @@ Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
   }
 }

-// to do, now just to adapter benchmark
 Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
                     const std::shared_ptr<Context> &model_context, const Key &dec_key,
                     const std::vector<char> &dec_mode, const std::vector<char> &cropto_lib_path) {
-  return Build(model_path, model_type, model_context);
+#ifdef ENABLE_OPENSSL
+  std::unique_lock<std::mutex> build_lock(g_build_mutex);
+  if (impl_ == nullptr) {
+    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
+    impl_ = std::make_shared<ModelImpl>();
+    if (impl_ == nullptr) {
+      MS_LOG(ERROR) << "Model implement is null.";
+      return kLiteFileError;
+    }
+  }
+
+  if (dec_key.len > 0) {
+    size_t model_size;
+    auto model_buf = lite::ReadFile(model_path.data(), &model_size);
+    if (model_buf == nullptr) {
+      MS_LOG(ERROR) << "Read model file failed";
+      return kLiteFileError;
+    }
+    std::unique_ptr<Byte[]> decrypt_buffer;
+    size_t decrypt_len = 0;
+    Status ret = DecryptModel(CharToString(cropto_lib_path), model_buf, model_size, dec_key, CharToString(dec_mode),
+                              &decrypt_buffer, &decrypt_len);
+    if (ret != kSuccess) {
+      MS_LOG(ERROR) << "Decrypt model failed.";
+      delete[] model_buf;
+      return ret;
+    }
+    try {
+      ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
+      if (ret != kSuccess) {
+        delete[] model_buf;
+        return ret;
+      }
+      delete[] model_buf;
+      return kSuccess;
+    } catch (const std::exception &exe) {
+      delete[] model_buf;
+      MS_LOG_ERROR << "Catch exception: " << exe.what();
+      return kCoreFailed;
+    }
+  } else {
+    try {
+      Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
+      if (ret != kSuccess) {
+        return ret;
+      }
+      return kSuccess;
+    } catch (const std::exception &exe) {
+      MS_LOG_ERROR << "Catch exception: " << exe.what();
+      return kCoreFailed;
+    }
+  }
+#else
+  MS_LOG(ERROR) << "The lib is not support Decrypt Model.";
+  return kLiteError;
+#endif
 }

+Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
+                    const std::shared_ptr<Context> &model_context, const Key &dec_key,
+                    const std::vector<char> &dec_mode, const std::vector<char> &cropto_lib_path) {
+#ifdef ENABLE_OPENSSL
+  std::unique_lock<std::mutex> build_lock(g_build_mutex);
+  if (impl_ == nullptr) {
+    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
+    impl_ = std::make_shared<ModelImpl>();
+    if (impl_ == nullptr) {
+      MS_LOG(ERROR) << "Model implement is null.";
+      return kLiteFileError;
+    }
+  }
+
+  if (dec_key.len > 0) {
+    std::unique_ptr<Byte[]> decrypt_buffer;
+    size_t decrypt_len = 0;
+    Status ret = DecryptModel(CharToString(cropto_lib_path), model_data, data_size, dec_key, CharToString(dec_mode),
+                              &decrypt_buffer, &decrypt_len);
+    if (ret != kSuccess) {
+      MS_LOG(ERROR) << "Decrypt model failed.";
+      return ret;
+    }
+    try {
+      ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
+      if (ret != kSuccess) {
+        return ret;
+      }
+      return kSuccess;
+    } catch (const std::exception &exe) {
+      MS_LOG_ERROR << "Catch exception: " << exe.what();
+      return kCoreFailed;
+    }
+  } else {
+    try {
+      Status ret = impl_->Build(model_data, data_size, model_type, model_context);
+      if (ret != kSuccess) {
+        return ret;
+      }
+      return kSuccess;
+    } catch (const std::exception &exe) {
+      MS_LOG_ERROR << "Catch exception: " << exe.what();
+      return kCoreFailed;
+    }
+  }
+#else
+  MS_LOG(ERROR) << "The lib is not support Decrypt Model.";
+  return kLiteError;
+#endif
+}

 Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
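One observation on the file-path branch above: model_buf is a raw buffer from lite::ReadFile, so every exit path repeats delete[]. An illustrative RAII alternative (not part of the commit):

#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Reading the model into an owned buffer lets every early return free it
// automatically, removing the repeated delete[] calls.
std::vector<char> ReadFileOwned(const std::string &path) {
  std::ifstream in(path, std::ios::binary);
  return std::vector<char>((std::istreambuf_iterator<char>(in)),
                           std::istreambuf_iterator<char>());
}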
@@ -103,6 +230,18 @@ Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_conte
   return kLiteNotSupport;
 }

+Status Model::Build(GraphCell graph, Node *optimizer, std::vector<Expr *> inputs,
+                    const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
+                             const std::shared_ptr<TrainCfg> &train_cfg = nullptr) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
 Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Failed because this model has not been built.";
@@ -140,29 +279,25 @@ Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor>
   }
 }

+Status Model::Predict(const MSKernelCallBack &before, const MSKernelCallBack &after) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
 Status Model::PredictWithPreprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs,
                                     const MSKernelCallBack &before, const MSKernelCallBack &after) {
-  if (impl_ == nullptr) {
-    MS_LOG(ERROR) << "Failed because this model has not been built.";
-    return kMCFailed;
-  }
-  return impl_->PredictWithPreprocess(inputs, outputs);
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
 }

 Status Model::Preprocess(const std::vector<std::vector<MSTensor>> &inputs, std::vector<MSTensor> *outputs) {
-  if (impl_ == nullptr) {
-    MS_LOG(ERROR) << "Failed because this model has not been built.";
-    return kMCFailed;
-  }
-  return impl_->Preprocess(inputs, outputs);
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
 }

 bool Model::HasPreprocess() {
-  if (impl_ == nullptr) {
-    MS_LOG(ERROR) << "Failed because this model has not been built.";
-    return false;
-  }
-  return impl_->HasPreprocess();
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return false;
 }

 std::vector<MSTensor> Model::GetInputs() {
@@ -222,12 +357,14 @@ MSTensor Model::GetOutputByTensorName(const std::vector<char> &name) {
 }

 std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
-  return std::vector<MSTensor>();
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return {};
 }

 Status Model::BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
                                     std::map<std::string, unsigned int> *outputGLTexture) {
-  return kSuccess;
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
 }

 Status Model::LoadConfig(const std::vector<char> &config_path) {
@@ -267,4 +404,91 @@ Status Model::UpdateConfig(const std::vector<char> &section,
 bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
   return ModelImpl::CheckModelSupport(device_type, model_type);
 }

+Status Model::UpdateWeights(const std::vector<MSTensor> &new_weights) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+std::vector<MSTensor> Model::GetTrainableParams() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return {};
+}
+
+std::vector<MSTensor> Model::GetGradients() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return {};
+}
+
+Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+std::vector<MSTensor> Model::GetFeatureMaps() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return {};
+}
+
+Status Model::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+std::vector<MSTensor> Model::GetOptimizerParams() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return {};
+}
+
+Status Model::SetOptimizerParams(const std::vector<MSTensor> &params) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+Status Model::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+Status Model::SetLearningRate(float learning_rate) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+float Model::GetLearningRate() {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return 0.0;
+}
+
+Status Model::InitMetrics(const std::vector<Metrics *> metrics) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+std::vector<Metrics *> Model::GetMetrics() {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return {};
+}
+
+Status Model::SetTrainMode(bool train) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+bool Model::GetTrainMode() const {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return false;
+}
+
+// cppcheck-suppress passedByValue
+Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
+
+// cppcheck-suppress passedByValue
+Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kLiteNotSupport;
+}
 }  // namespace mindspore
@@ -84,6 +84,10 @@ ConverterFunc ConverterPlugin::GetConverterFunc() {

 Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
                                     const std::shared_ptr<Context> &model_context, const std::string &model_path) {
+  if (model_type != kMindIR) {
+    MS_LOG(ERROR) << "Invalid model type";
+    return kLiteError;
+  }
   const void *model_buff = model_data;
   size_t model_size = data_size;
   auto mindir_path = GetConfig(lite::kConfigModelFileSection, lite::kConfigMindIRPathKey);
@@ -211,6 +215,16 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
     MS_LOG(ERROR) << "Dims is null.";
     return kLiteInputParamInvalid;
   }
+  for (size_t j = 0; j < dims.size(); j++) {
+    auto dims_v = dims[j];
+    for (size_t i = 0; i < dims_v.size(); i++) {
+      auto dim = dims_v[i];
+      if (dim <= 0 || dim > INT_MAX) {
+        MS_LOG(ERROR) << "Invalid shape! dim: " << dim;
+        return kLiteInputParamInvalid;
+      }
+    }
+  }
   if (inputs.size() != dims.size()) {
     MS_LOG(ERROR) << "The size of inputs does not match the size of dims.";
     return kLiteInputParamInvalid;
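The new loop rejects any dimension that is non-positive or larger than INT_MAX before the inputs/dims size check runs. A hedged helper sketch showing a typical Resize call that passes this validation (model setup omitted):

#include <cstdint>
#include <vector>
#include "include/api/model.h"

// Resize every input to a new batch size; batch must be > 0 and <= INT_MAX,
// otherwise ModelImpl::Resize now fails with kLiteInputParamInvalid.
mindspore::Status ResizeToBatch(mindspore::Model *model, int64_t batch) {
  auto inputs = model->GetInputs();
  std::vector<std::vector<int64_t>> dims;
  for (auto &tensor : inputs) {
    auto shape = tensor.Shape();  // copy of the current shape
    shape[0] = batch;
    dims.push_back(shape);
  }
  return model->Resize(inputs, dims);
}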
@@ -367,10 +367,7 @@ int TensorRTExecutor::ParseOptimizationProfile() {
   }
   trt_profile_configs_ = profile_configs;
   auto precision_mode = ProfileParser::GetOption(gpu_context, lite::kPrecisionModeKey, "");
-  if (precision_mode == "force_fp16") {
-    device_info_->SetEnableFP16(true);
-    MS_LOG(INFO) << "Set precision mode to fp16 by config file";
-  }
+  device_info_->SetPrecisionMode(precision_mode);
   return RET_OK;
 }