From 0a0419498f5349aa365d0a1235fead3204bbfff4 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Thu, 9 Dec 2021 15:20:31 +0800 Subject: [PATCH] C++ inference interface Ascend310DeviceInfo+Ascend910DeviceInfo->AscendDeviceInfo --- include/api/context.h | 62 +- include/api/model.h | 2 +- mindspore/ccsrc/cxx_api/acl_utils.h | 48 ++ mindspore/ccsrc/cxx_api/context.cc | 57 +- mindspore/ccsrc/cxx_api/factory.h | 66 +- .../ccsrc/cxx_api/graph/acl/acl_graph_impl.cc | 11 +- .../ccsrc/cxx_api/graph/acl/acl_graph_impl.h | 1 + .../cxx_api/graph/ascend/ascend_graph_impl.cc | 11 +- .../cxx_api/graph/ascend/ascend_graph_impl.h | 1 + .../ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc | 4 +- .../ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h | 2 + mindspore/ccsrc/cxx_api/graph/graph_impl.h | 2 + .../ccsrc/cxx_api/model/acl/acl_model.cc | 19 +- mindspore/ccsrc/cxx_api/model/acl/acl_model.h | 1 + .../cxx_api/model/acl/acl_model_multi.cc | 2 +- .../cxx_api/model/acl/acl_model_options.cc | 2 +- mindspore/ccsrc/cxx_api/model/model.cc | 23 +- mindspore/ccsrc/cxx_api/model/model_impl.h | 3 +- mindspore/ccsrc/cxx_api/model/ms/ms_model.cc | 24 +- mindspore/ccsrc/cxx_api/model/ms/ms_model.h | 1 + .../mnist_stm32f746/include/api/context.h | 571 ++++++++++++------ mindspore/lite/src/cxx_api/context.cc | 54 +- mindspore/lite/src/cxx_api/converters.cc | 2 +- tests/st/cpp/common/common_test.cc | 4 +- tests/st/cpp/dataset/test_de.cc | 6 +- tests/st/cpp/model/test_dynamic_batch_size.cc | 2 +- tests/st/cpp/model/test_zero_copy.cc | 4 +- tests/ut/cpp/cxx_api/context_test.cc | 21 +- 28 files changed, 632 insertions(+), 374 deletions(-) create mode 100644 mindspore/ccsrc/cxx_api/acl_utils.h diff --git a/include/api/context.h b/include/api/context.h index 552dd4f8710..92bb2f5e008 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -28,6 +28,7 @@ enum DeviceType { kCPU = 0, kGPU, kKirinNPU, + kAscend, kAscend910, kAscend310, // add new type here @@ -287,34 +288,14 @@ void 
GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { } std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } -/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is -/// invalid for MindSpore Lite. -class MS_API Ascend910DeviceInfo : public DeviceInfoContext { - public: - /// \brief Get the type of this DeviceInfoContext. - /// - /// \return Type of this DeviceInfoContext. - enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }; - - /// \brief Set device id. - /// - /// \param[in] device_id The device id. - void SetDeviceID(uint32_t device_id); - - /// \brief Get the device id. - /// - /// \return The device id. - uint32_t GetDeviceID() const; -}; - /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is /// invalid for MindSpore Lite. -class MS_API Ascend310DeviceInfo : public DeviceInfoContext { +class MS_API AscendDeviceInfo : public DeviceInfoContext { public: /// \brief Get the type of this DeviceInfoContext. /// /// \return Type of this DeviceInfoContext. - enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }; + enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; }; /// \brief Set device id. 
/// @@ -447,45 +428,48 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext { std::vector GetBufferOptimizeModeChar() const; }; -void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { +using Ascend310DeviceInfo = AscendDeviceInfo; +using Ascend910DeviceInfo = AscendDeviceInfo; + +void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { SetInsertOpConfigPath(StringToChar(cfg_path)); } -std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } +std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } -void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } -std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } +void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } +std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } -void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } -std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } +void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } +std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } -std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } +std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } -void Ascend310DeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { +void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { SetDynamicImageSize(StringToChar(dynamic_image_size)); } -std::string 
Ascend310DeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } +std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } -void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) { +void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { SetPrecisionMode(StringToChar(precision_mode)); } -std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } +std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } -void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { +void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { SetOpSelectImplMode(StringToChar(op_select_impl_mode)); } -std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } +std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } -void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { +void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { SetFusionSwitchConfigPath(StringToChar(cfg_path)); } -std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const { +std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const { return CharToString(GetFusionSwitchConfigPathChar()); } -void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { +void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); } -std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } +std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } } // namespace 
mindspore #endif // MINDSPORE_INCLUDE_API_CONTEXT_H diff --git a/include/api/model.h b/include/api/model.h index 007e584a8a8..2b77eeccc70 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -193,7 +193,7 @@ class MS_API Model { /// \brief Inference model. /// - /// \param[in] device_type Device type,options are kGPU, kAscend910, etc. + /// \param[in] device_type Device type,options are kGPU, kAscend, kAscend910, etc. /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. /// /// \return Is supported or not. diff --git a/mindspore/ccsrc/cxx_api/acl_utils.h b/mindspore/ccsrc/cxx_api/acl_utils.h new file mode 100644 index 00000000000..2a1a5614955 --- /dev/null +++ b/mindspore/ccsrc/cxx_api/acl_utils.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H +#define MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H + +#include +#include "acl/acl_base.h" +namespace mindspore { +static inline bool IsAscend910Soc() { + const char *soc_name_c = aclrtGetSocName(); + if (soc_name_c == nullptr) { + return false; + } + std::string soc_name(soc_name_c); + if (soc_name.find("910") == std::string::npos) { + return false; + } + return true; +} + +static inline bool IsAscendNo910Soc() { + const char *soc_name_c = aclrtGetSocName(); + if (soc_name_c == nullptr) { + return false; + } + std::string soc_name(soc_name_c); + if (soc_name.find("910") != std::string::npos) { + return false; + } + return true; +} +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc index 22b03a19add..7db70c8bff8 100644 --- a/mindspore/ccsrc/cxx_api/context.cc +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -175,55 +175,46 @@ std::vector GPUDeviceInfo::GetPrecisionModeChar() const { return StringToChar(ref); } -void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { - MS_EXCEPTION_IF_NULL(data_); - data_->params[kModelOptionAscend910DeviceID] = device_id; -} -uint32_t Ascend910DeviceInfo::GetDeviceID() const { - MS_EXCEPTION_IF_NULL(data_); - return GetValue(data_, kModelOptionAscend910DeviceID); -} - -void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { +void AscendDeviceInfo::SetDeviceID(uint32_t device_id) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310DeviceID] = device_id; } -uint32_t Ascend310DeviceInfo::GetDeviceID() const { +uint32_t AscendDeviceInfo::GetDeviceID() const { MS_EXCEPTION_IF_NULL(data_); return GetValue(data_, kModelOptionAscend310DeviceID); } -void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector &cfg_path) { +void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector &cfg_path) { MS_EXCEPTION_IF_NULL(data_); 
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); } -std::vector Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { +std::vector AscendDeviceInfo::GetInsertOpConfigPathChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, kModelOptionAscend310InsertOpCfgPath); return StringToChar(ref); } -void Ascend310DeviceInfo::SetInputFormat(const std::vector &format) { +void AscendDeviceInfo::SetInputFormat(const std::vector &format) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310InputFormat] = CharToString(format); } -std::vector Ascend310DeviceInfo::GetInputFormatChar() const { +std::vector AscendDeviceInfo::GetInputFormatChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, kModelOptionAscend310InputFormat); return StringToChar(ref); } -void Ascend310DeviceInfo::SetInputShape(const std::vector &shape) { +void AscendDeviceInfo::SetInputShape(const std::vector &shape) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310InputShape] = CharToString(shape); } -std::vector Ascend310DeviceInfo::GetInputShapeChar() const { +std::vector AscendDeviceInfo::GetInputShapeChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, kModelOptionAscend310InputShape); return StringToChar(ref); } -void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector &dynamic_batch_size) { +void AscendDeviceInfo::SetDynamicBatchSize(const std::vector &dynamic_batch_size) { MS_EXCEPTION_IF_NULL(data_); std::string batchs = ""; for (size_t i = 0; i < dynamic_batch_size.size(); ++i) { @@ -234,69 +225,69 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector &dynamic } data_->params[kModelOptionAscend310DynamicBatchSize] = batchs; } -std::vector Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { +std::vector AscendDeviceInfo::GetDynamicBatchSizeChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, 
kModelOptionAscend310DynamicBatchSize); return StringToChar(ref); } -void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector &dynamic_image_size) { return; } +void AscendDeviceInfo::SetDynamicImageSize(const std::vector &dynamic_image_size) { return; } -std::vector Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return std::vector(); } +std::vector AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector(); } -void Ascend310DeviceInfo::SetPrecisionMode(const std::vector &precision_mode) { +void AscendDeviceInfo::SetPrecisionMode(const std::vector &precision_mode) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode); } -std::vector Ascend310DeviceInfo::GetPrecisionModeChar() const { +std::vector AscendDeviceInfo::GetPrecisionModeChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, kModelOptionAscend310PrecisionMode); return StringToChar(ref); } -void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector &op_select_impl_mode) { +void AscendDeviceInfo::SetOpSelectImplMode(const std::vector &op_select_impl_mode) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode); } -std::vector Ascend310DeviceInfo::GetOpSelectImplModeChar() const { +std::vector AscendDeviceInfo::GetOpSelectImplModeChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, kModelOptionAscend310OpSelectImplMode); return StringToChar(ref); } -void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector &cfg_path) { +void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector &cfg_path) { MS_EXCEPTION_IF_NULL(data_); data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path); } -std::vector Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { +std::vector AscendDeviceInfo::GetFusionSwitchConfigPathChar() const { MS_EXCEPTION_IF_NULL(data_); 
const std::string &ref = GetValue(data_, KModelOptionAscend310FusionSwitchCfgPath); return StringToChar(ref); } -void Ascend310DeviceInfo::SetInputShapeMap(const std::map> &shape) { +void AscendDeviceInfo::SetInputShapeMap(const std::map> &shape) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310InputShapeMap] = shape; } -std::map> Ascend310DeviceInfo::GetInputShapeMap() const { +std::map> AscendDeviceInfo::GetInputShapeMap() const { MS_EXCEPTION_IF_NULL(data_); return GetValue>>(data_, kModelOptionAscend310InputShapeMap); } -void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { +void AscendDeviceInfo::SetOutputType(enum DataType output_type) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310OutputType] = output_type; } -enum DataType Ascend310DeviceInfo::GetOutputType() const { +enum DataType AscendDeviceInfo::GetOutputType() const { MS_EXCEPTION_IF_NULL(data_); return GetValue(data_, kModelOptionAscend310OutputType); } -void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector &buffer_optimize_mode) { +void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector &buffer_optimize_mode) { MS_EXCEPTION_IF_NULL(data_); data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode); } -std::vector Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { +std::vector AscendDeviceInfo::GetBufferOptimizeModeChar() const { MS_EXCEPTION_IF_NULL(data_); const std::string &ref = GetValue(data_, kModelOptionAscend310BufferOptimize); return StringToChar(ref); diff --git a/mindspore/ccsrc/cxx_api/factory.h b/mindspore/ccsrc/cxx_api/factory.h index 7cc175b9d84..421a6e78429 100644 --- a/mindspore/ccsrc/cxx_api/factory.h +++ b/mindspore/ccsrc/cxx_api/factory.h @@ -24,7 +24,31 @@ #include "utils/utils.h" namespace mindspore { -inline std::string g_device_target = "Default"; +inline enum DeviceType g_device_target = kInvalidDeviceType; + +static inline LogStream &operator<<(LogStream &stream, 
DeviceType device_type) { + switch (device_type) { + case kAscend: + stream << "Ascend"; + break; + case kAscend910: + stream << "Ascend910"; + break; + case kAscend310: + stream << "Ascend310"; + break; + case kGPU: + stream << "GPU"; + break; + case kCPU: + stream << "CPU"; + break; + default: + stream << "[InvalidDeviceType: " << static_cast(device_type) << "]"; + break; + } + return stream; +} template class Factory { @@ -39,32 +63,24 @@ class Factory { return instance; } - void Register(const std::string &device_name, U &&creator) { - if (creators_.find(device_name) == creators_.end()) { - (void)creators_.emplace(device_name, creator); + void Register(U &&creator) { creators_.push_back(creator); } + + std::shared_ptr Create(enum DeviceType device_type) { + for (auto &item : creators_) { + MS_EXCEPTION_IF_NULL(item); + auto val = item(); + if (val->CheckDeviceSupport(device_type)) { + return val; + } } - } - - bool CheckModelSupport(const std::string &device_name) { - return std::any_of(creators_.begin(), creators_.end(), - [&device_name](const std::pair &item) { return item.first == device_name; }); - } - - std::shared_ptr Create(const std::string &device_name) { - auto iter = creators_.find(device_name); - if (creators_.end() != iter) { - MS_EXCEPTION_IF_NULL(iter->second); - return (iter->second)(); - } - - MS_LOG(ERROR) << "Unsupported device target " << device_name; + MS_LOG(WARNING) << "Unsupported device target " << device_type; return nullptr; } private: Factory() = default; ~Factory() = default; - std::map creators_; + std::vector creators_; }; template @@ -72,14 +88,12 @@ class Registrar { using U = std::function()>; public: - Registrar(const std::string &device_name, U creator) { - Factory::Instance().Register(device_name, std::move(creator)); - } + explicit Registrar(U creator) { Factory::Instance().Register(std::move(creator)); } ~Registrar() = default; }; -#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ - static const Registrar 
g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ - #DEVICE_NAME, []() { return std::make_shared(); }); +#define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS) \ + static const Registrar g_api_##DERIVE_CLASS##_registrar_reg( \ + []() { return std::make_shared(); }); } // namespace mindspore #endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc index 063961f8179..f0b58e77d6a 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc @@ -18,9 +18,10 @@ #include "cxx_api/model/acl/model_converter.h" #include "utils/log_adapter.h" #include "mindspore/core/utils/convert_utils_base.h" +#include "cxx_api/acl_utils.h" namespace mindspore { -API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); +API_FACTORY_REG(GraphCell::GraphImpl, AclGraphImpl); AclGraphImpl::AclGraphImpl() : init_flag_(false), @@ -231,4 +232,12 @@ Status AclGraphImpl::ConvertToOM() { MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); return kMCFailed; } + +bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { + // for Ascend, only support kAscend and kAscend310 + if (device_type != kAscend && device_type != kAscend310) { + return false; + } + return IsAscendNo910Soc(); +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h index ab4994c83df..61ac4b39527 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.h @@ -37,6 +37,7 @@ class AclGraphImpl : public GraphCell::GraphImpl { Status Load(uint32_t device_id) override; std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckDeviceSupport(mindspore::DeviceType device_type) override; private: Status ConvertToOM(); diff --git 
a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc index 05b74dd2954..5b190287d7e 100644 --- a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc @@ -18,6 +18,7 @@ #include "include/api/context.h" #include "cxx_api/factory.h" #include "cxx_api/akg_kernel_register.h" +#include "cxx_api/acl_utils.h" #include "utils/log_adapter.h" #include "utils/context/context_extends.h" #include "mindspore/core/base/base_ref_utils.h" @@ -30,7 +31,7 @@ #include "pybind11/pybind11.h" namespace mindspore { -API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); +API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl); static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL"; static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE"; @@ -382,6 +383,14 @@ std::shared_ptr AscendGraphImpl::MsEnvGuard::GetEnv return acl_env; } +bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { + // for Ascend, only support kAscend and kAscend910 + if (device_type != kAscend && device_type != kAscend910) { + return false; + } + return IsAscend910Soc(); +} + std::map> AscendGraphImpl::MsEnvGuard::global_ms_env_; std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; diff --git a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h index d27a8eb6dca..a38d5b491f3 100644 --- a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h @@ -39,6 +39,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl { Status Load(uint32_t device_id) override; std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckDeviceSupport(mindspore::DeviceType device_type) override; private: class MsEnvGuard; diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc 
b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc index 0a4350a9258..f55f90a0d84 100644 --- a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc @@ -26,7 +26,7 @@ #include "runtime/device/gpu/cuda_driver.h" namespace mindspore { -API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl); +API_FACTORY_REG(GraphCell::GraphImpl, GPUGraphImpl); GPUGraphImpl::GPUGraphImpl() : session_impl_(nullptr), @@ -291,4 +291,6 @@ std::vector GPUGraphImpl::GetOutputs() { } return result; } + +bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; } } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h index a4e43abf8a8..2e6c6c06834 100644 --- a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.h @@ -37,6 +37,8 @@ class GPUGraphImpl : public GraphCell::GraphImpl { std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckDeviceSupport(mindspore::DeviceType device_type) override; + private: Status InitEnv(); Status FinalizeEnv(); diff --git a/mindspore/ccsrc/cxx_api/graph/graph_impl.h b/mindspore/ccsrc/cxx_api/graph/graph_impl.h index 2b678e63409..626d64b23aa 100644 --- a/mindspore/ccsrc/cxx_api/graph/graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/graph_impl.h @@ -42,6 +42,8 @@ class GraphCell::GraphImpl { virtual std::vector GetInputs() = 0; virtual std::vector GetOutputs() = 0; + virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0; + protected: std::shared_ptr graph_; std::shared_ptr graph_context_; diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc index e0b2fc0c0b6..30670e5f124 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc @@ -21,7 +21,7 @@ #include 
"include/api/context.h" #include "cxx_api/factory.h" #include "cxx_api/graph/acl/acl_env_guard.h" -#include "acl/acl_base.h" +#include "cxx_api/acl_utils.h" namespace mindspore { Status AclModel::Build() { @@ -112,7 +112,7 @@ Status AclModel::Resize(const std::vector &inputs, const std::vector(); - model_context_->MutableDeviceInfo().emplace_back(std::make_shared()); + model_context_->MutableDeviceInfo().emplace_back(std::make_shared()); } std::string input_shape_option; @@ -139,7 +139,7 @@ Status AclModel::Resize(const std::vector &inputs, const std::vectorCast(); + auto ascend310_info = device_infos[0]->Cast(); MS_EXCEPTION_IF_NULL(ascend310_info); ascend310_info->SetInputShape(input_shape_option); auto graph_cell_bak = std::move(graph_cell_); @@ -163,16 +163,15 @@ std::vector AclModel::GetOutputs() { return graph_cell_->GetOutputs(); } -bool AclModel::CheckModelSupport(enum ModelType model_type) { - const char *soc_name_c = aclrtGetSocName(); - if (soc_name_c == nullptr) { - return false; - } - std::string soc_name(soc_name_c); - if (soc_name.find("910") != std::string::npos) { +bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) { + // for Ascend, only support kAscend and kAscend310 + if (device_type != kAscend && device_type != kAscend310) { return false; } + return IsAscendNo910Soc(); +} +bool AclModel::CheckModelSupport(enum ModelType model_type) { static const std::set kSupportedModelMap = {kMindIR, kOM}; auto iter = kSupportedModelMap.find(model_type); if (iter == kSupportedModelMap.end()) { diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h index 79b93a72edb..e2bd9215dbd 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.h +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.h @@ -43,6 +43,7 @@ class AclModel : public ModelImpl { std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckDeviceSupport(mindspore::DeviceType device_type) override; bool 
CheckModelSupport(enum ModelType model_type) override; private: diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc index c746b0726e5..10754f83d5d 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc @@ -29,7 +29,7 @@ #include "cxx_api/model/acl/acl_vm/acl_vm.h" namespace mindspore { -API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti); +API_FACTORY_REG(ModelImpl, AclModelMulti); namespace { std::map kDtypeMap = { diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc index ed186415f0a..32fcd384a3f 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc @@ -33,7 +33,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr &context) { if (device_infos.size() != 1) { return; } - auto ascend310_info = device_infos[0]->Cast(); + auto ascend310_info = device_infos[0]->Cast(); if (ascend310_info == nullptr) { return; } diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 4087d971c47..17b881f2f5f 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -20,19 +20,6 @@ #include "utils/utils.h" namespace mindspore { -namespace { -std::string GetDeviceTypeString(enum DeviceType type) { - static const std::map kDeviceTypeStrs = { - {kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"}, - }; - auto iter = kDeviceTypeStrs.find(type); - if (iter != kDeviceTypeStrs.end()) { - return iter->second; - } - - return "InvalidDeviceType" + std::to_string(static_cast(type)); -} -} // namespace Status Model::Build(GraphCell graph_cell, const std::shared_ptr &model_context, const std::shared_ptr &) { if (graph_cell.GetGraph() == nullptr) { @@ -50,7 +37,7 @@ Status 
Model::Build(GraphCell graph_cell, const std::shared_ptr &model_ return kMCInvalidInput; } - std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType()); + auto device_target = device_info[0]->GetDeviceType(); impl_ = Factory::Instance().Create(device_target); if (impl_ == nullptr) { MS_LOG(ERROR) << "Create session type " << device_target << " failed"; @@ -175,16 +162,10 @@ Model::Model() : impl_(nullptr) {} Model::~Model() {} bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) { - std::string device_type_str = GetDeviceTypeString(device_type); - if (!Factory::Instance().CheckModelSupport(device_type_str)) { - return false; - } - - auto check_model = Factory::Instance().Create(device_type_str); + auto check_model = Factory::Instance().Create(device_type); if (check_model == nullptr) { return false; } - return check_model->CheckModelSupport(model_type); } diff --git a/mindspore/ccsrc/cxx_api/model/model_impl.h b/mindspore/ccsrc/cxx_api/model/model_impl.h index 5de592d557b..baabb280ab2 100644 --- a/mindspore/ccsrc/cxx_api/model/model_impl.h +++ b/mindspore/ccsrc/cxx_api/model/model_impl.h @@ -44,7 +44,8 @@ class ModelImpl { virtual std::vector GetInputs() = 0; virtual std::vector GetOutputs() = 0; - virtual bool CheckModelSupport(enum ModelType model_type) { return false; } + virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0; + virtual bool CheckModelSupport(enum ModelType model_type) = 0; virtual Status Preprocess(const std::vector &inputs, std::vector *outputs); diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc index 911726038fa..0b463bd5ca0 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc @@ -20,14 +20,13 @@ #include "include/api/context.h" #include "utils/ms_context.h" #include "cxx_api/factory.h" +#if ENABLE_D +#include "cxx_api/acl_utils.h" +#endif namespace mindspore { // 
mindspore-serving check current package for version check with ModelImpl factory. -#if ENABLE_D -API_FACTORY_REG(ModelImpl, Ascend910, MsModel); -#elif ENABLE_GPU -API_FACTORY_REG(ModelImpl, GPU, MsModel); -#endif +API_FACTORY_REG(ModelImpl, MsModel); static std::string GenerateShapeKey(const std::vector> &dims) { std::string shape_key; @@ -171,18 +170,23 @@ uint32_t MsModel::GetDeviceID() const { return 0; } -bool MsModel::CheckModelSupport(enum ModelType model_type) { +bool MsModel::CheckDeviceSupport(enum DeviceType device_type) { #if ENABLE_D - const char *soc_name_c = aclrtGetSocName(); - if (soc_name_c == nullptr) { + // for Ascend, only support kAscend or kAscend910 + if (device_type != kAscend && device_type != kAscend910) { return false; } - std::string soc_name(soc_name_c); - if (soc_name.find("910") == std::string::npos) { + return IsAscend910Soc(); +#else + // otherwise, only support GPU + if (device_type != kGPU) { return false; } + return true; #endif +} +bool MsModel::CheckModelSupport(mindspore::ModelType model_type) { static const std::set kSupportedModelMap = {kMindIR}; auto iter = kSupportedModelMap.find(model_type); if (iter == kSupportedModelMap.end()) { diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h index 78cefd4eae6..93c012a350d 100644 --- a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h +++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h @@ -44,6 +44,7 @@ class MsModel : public ModelImpl { std::vector GetInputs() override; std::vector GetOutputs() override; + bool CheckDeviceSupport(mindspore::DeviceType device_type) override; bool CheckModelSupport(enum ModelType model_type) override; private: diff --git a/mindspore/lite/micro/example/mnist_stm32f746/mnist_stm32f746/include/api/context.h b/mindspore/lite/micro/example/mnist_stm32f746/mnist_stm32f746/include/api/context.h index 0e3ddf29e85..8fcaf32fa1c 100755 --- 
a/mindspore/lite/micro/example/mnist_stm32f746/mnist_stm32f746/include/api/context.h +++ b/mindspore/lite/micro/example/mnist_stm32f746/mnist_stm32f746/include/api/context.h @@ -1,18 +1,18 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* Copyright 2020 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H #define MINDSPORE_INCLUDE_API_CONTEXT_H @@ -25,212 +25,431 @@ namespace mindspore { enum DeviceType { - kCPU = 0, - kGPU, - kKirinNPU, - kAscend910, - kAscend310, - // add new type here - kInvalidDeviceType = 100, + kCPU = 0, + kGPU, + kKirinNPU, + kAscend, + kAscend910, + kAscend310, + // add new type here + kInvalidDeviceType = 100, }; class Allocator; +class Delegate; class DeviceInfoContext; +/// \brief Context is used to store environment variables during execution. 
class MS_API Context { - public: - Context(); - ~Context() = default; +public: + struct Data; + Context(); + ~Context() = default; - void SetThreadNum(int32_t thread_num); - int32_t GetThreadNum() const; + /// \brief Set the number of threads at runtime. Only valid for Lite. + /// + /// \param[in] thread_num the number of threads at runtime. + void SetThreadNum(int32_t thread_num); - void SetAllocator(const std::shared_ptr &allocator); - std::shared_ptr GetAllocator() const; + /// \brief Get the current thread number setting. Only valid for Lite. + /// + /// \return The current thread number setting. + int32_t GetThreadNum() const; - std::vector> &MutableDeviceInfo(); + /// \brief Set the thread affinity to CPU cores. Only valid for Lite. + /// + /// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first + void SetThreadAffinity(int mode); - private: - struct Data; - std::shared_ptr data_; + /// \brief Get the thread affinity of CPU cores. Only valid for Lite. + /// + /// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first + int GetThreadAffinityMode() const; + + /// \brief Set the thread lists to CPU cores. Only valid for Lite. + /// + /// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the + /// mode is not effective. + /// + /// \param[in] core_list: a vector of thread core lists. + void SetThreadAffinity(const std::vector &core_list); + + /// \brief Get the thread lists of CPU cores. Only valid for Lite. + /// + /// \return core_list: a vector of thread core lists. + std::vector GetThreadAffinityCoreList() const; + + /// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite. + /// + /// \param[in] is_parallel: true, parallel; false, not in parallel. + void SetEnableParallel(bool is_parallel); + + /// \brief Get the status whether to perform model inference or training in parallel. 
Only valid for Lite. + /// + /// \return Bool value that indicates whether in parallel. + bool GetEnableParallel() const; + + /// \brief Set Delegate to access third-party AI framework. Only valid for Lite. + /// + /// \param[in] Pointer to the custom delegate. + void SetDelegate(const std::shared_ptr &delegate); + + /// \brief Get the delegate of the third-party AI framework. Only valid for Lite. + /// + /// \return Pointer to the custom delegate. + std::shared_ptr GetDelegate() const; + + /// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports + /// heterogeneous scenarios with multiple members in the vector. + /// + /// \return Mutable reference of DeviceInfoContext vector in this context. + std::vector> &MutableDeviceInfo(); + +private: + std::shared_ptr data_; }; +/// \brief DeviceInfoContext defines different device contexts. class MS_API DeviceInfoContext : public std::enable_shared_from_this { - public: - struct Data; +public: + struct Data; - DeviceInfoContext(); - virtual ~DeviceInfoContext() = default; - virtual enum DeviceType GetDeviceType() const = 0; + DeviceInfoContext(); + virtual ~DeviceInfoContext() = default; - template - std::shared_ptr Cast() { - static_assert(std::is_base_of::value, "Wrong cast type."); - if (GetDeviceType() != T().GetDeviceType()) { - return nullptr; - } + /// \brief Get the type of this DeviceInfoContext. + /// + /// \return Type of this DeviceInfoContext. + virtual enum DeviceType GetDeviceType() const = 0; - return std::static_pointer_cast(shared_from_this()); - } + /// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts + /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails. + /// + /// \param T Type + /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr. 
+ template + std::shared_ptr Cast() { + static_assert(std::is_base_of::value, "Wrong cast type."); + if (GetDeviceType() != T().GetDeviceType()) { + return nullptr; + } - protected: - std::shared_ptr data_; + return std::static_pointer_cast(shared_from_this()); + } + /// \brief obtain provider's name + /// + /// \return provider's name. + std::string GetProvider() const; + /// \brief set provider's name. + /// + /// \param[in] provider define the provider's name. + + void SetProvider(const std::string &provider); + /// \brief obtain provider's device type. + /// + /// \return provider's device type. + + std::string GetProviderDevice() const; + /// \brief set provider's device type. + /// + /// \param[in] device define the provider's device type.EG: CPU. + void SetProviderDevice(const std::string &device); + + /// \brief set memory allocator. + /// + /// \param[in] allocator define the memory allocator which can be defined by user. + void SetAllocator(const std::shared_ptr &allocator); + + /// \brief obtain memory allocator. + /// + /// \return memory allocator. + std::shared_ptr GetAllocator() const; + +protected: + std::shared_ptr data_; }; +/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid +/// for MindSpore Lite. class MS_API CPUDeviceInfo : public DeviceInfoContext { - public: - enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }; +public: + /// \brief Get the type of this DeviceInfoContext. + /// + /// \return Type of this DeviceInfoContext. + enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }; - /// \brief Set the thread affinity to CPU cores. 
- /// - /// \param mode: 0: no affinities, 1: big cores first, 2: little cores first - void SetThreadAffinity(int mode); - int GetThreadAffinity() const; - void SetEnableFP16(bool is_fp16); - bool GetEnableFP16() const; + /// \brief Set enables to perform the float16 inference + /// + /// \param[in] is_fp16 Enable float16 inference or not. + void SetEnableFP16(bool is_fp16); + + /// \brief Get enables to perform the float16 inference + /// + /// \return Whether enable float16 inference. + bool GetEnableFP16() const; }; +/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid +/// for MindSpore Lite. class MS_API KirinNPUDeviceInfo : public DeviceInfoContext { - public: - enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }; +public: + /// \brief Get the type of this DeviceInfoContext. + /// + /// \return Type of this DeviceInfoContext. + enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }; - void SetFrequency(int frequency); - int GetFrequency() const; + /// \brief Set the NPU frequency. + /// + /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme + /// performance), default as 3. + void SetFrequency(int frequency); + + /// \brief Get the NPU frequency. + /// + /// \return NPU frequency + int GetFrequency() const; }; +/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU. class MS_API GPUDeviceInfo : public DeviceInfoContext { - public: - enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }; +public: + /// \brief Get the type of this DeviceInfoContext. + /// + /// \return Type of this DeviceInfoContext. + enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }; - void SetDeviceID(uint32_t device_id); - uint32_t GetDeviceID() const; + /// \brief Set device id. + /// + /// \param[in] device_id The device id. 
+ void SetDeviceID(uint32_t device_id); - void SetGpuTrtInferMode(bool gpu_trt_infer_mode); - bool GetGpuTrtInferMode() const; + /// \brief Get the device id. + /// + /// \return The device id. + uint32_t GetDeviceID() const; - void SetEnableFP16(bool is_fp16); - bool GetEnableFP16() const; + /// \brief Get the distribution rank id. + /// + /// \return The rank id. + int GetRankID() const; + + /// \brief Get the distribution group size. + /// + /// \return The group size. + int GetGroupSize() const; + + /// \brief Set the precision mode. + /// + /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default. + inline void SetPrecisionMode(const std::string &precision_mode); + + /// \brief Get the precision mode. + /// + /// \return The precision mode. + inline std::string GetPrecisionMode() const; + + /// \brief Set enables to perform the float16 inference + /// + /// \param[in] is_fp16 Enable float16 inference or not. + void SetEnableFP16(bool is_fp16); + + /// \brief Get enables to perform the float16 inference + /// + /// \return Whether enable float16 inference. + bool GetEnableFP16() const; + +private: + void SetPrecisionMode(const std::vector &precision_mode); + std::vector GetPrecisionModeChar() const; }; -class MS_API Ascend910DeviceInfo : public DeviceInfoContext { - public: - enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }; +void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { + SetPrecisionMode(StringToChar(precision_mode)); +} +std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } - void SetDeviceID(uint32_t device_id); - uint32_t GetDeviceID() const; +/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is +invalid for MindSpore Lite. +class MS_API AscendDeviceInfo : public DeviceInfoContext { +public: + /// \brief Get the type of this DeviceInfoContext.
+ /// + /// \return Type of this DeviceInfoContext. + enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; }; + + /// \brief Set device id. + /// + /// \param[in] device_id The device id. + void SetDeviceID(uint32_t device_id); + + /// \brief Get the device id. + /// + /// \return The device id. + uint32_t GetDeviceID() const; + + /// \brief Set AIPP configuration file path. + /// + /// \param[in] cfg_path AIPP configuration file path. + inline void SetInsertOpConfigPath(const std::string &cfg_path); + + /// \brief Get AIPP configuration file path. + /// + /// \return AIPP configuration file path. + inline std::string GetInsertOpConfigPath() const; + + /// \brief Set format of model inputs. + /// + /// \param[in] format Optional "NCHW", "NHWC", etc. + inline void SetInputFormat(const std::string &format); + + /// \brief Get format of model inputs. + /// + /// \return The format of model inputs. + inline std::string GetInputFormat() const; + + /// \brief Set shape of model inputs. + /// + /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1". + inline void SetInputShape(const std::string &shape); + + /// \brief Get shape of model inputs. + /// + /// \return The shape of model inputs. + inline std::string GetInputShape() const; + + /// \brief Set shape of model inputs. + /// + /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input + /// shape 4,3,2,1. + void SetInputShapeMap(const std::map> &shape); + + /// \brief Get shape of model inputs. + /// + /// \return The shape of model inputs. + std::map> GetInputShapeMap() const; + + void SetDynamicBatchSize(const std::vector &dynamic_batch_size); + inline std::string GetDynamicBatchSize() const; + + /// \brief Set the dynamic image size of model inputs. + /// + /// \param[in] dynamic_image_size Image size hw, e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
+ inline void SetDynamicImageSize(const std::string &dynamic_image_size); + + /// \brief Get dynamic image size of model inputs. + /// + /// \return The image size of model inputs. + inline std::string GetDynamicImageSize() const; + + /// \brief Set type of model outputs. + /// + /// \param[in] output_type FP32, UINT8 or FP16, default as FP32. + void SetOutputType(enum DataType output_type); + + /// \brief Get type of model outputs. + /// + /// \return The set type of model outputs. + enum DataType GetOutputType() const; + + /// \brief Set precision mode of model. + /// + /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and + /// "allow_mix_precision", "force_fp16" is set as default + inline void SetPrecisionMode(const std::string &precision_mode); + + /// \brief Get precision mode of model. + /// + /// \return The set type of model outputs + inline std::string GetPrecisionMode() const; + + /// \brief Set op select implementation mode. + /// + /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as + /// default. + inline void SetOpSelectImplMode(const std::string &op_select_impl_mode); + + /// \brief Get op select implementation mode. + /// + /// \return The set op select implementation mode. 
+ inline std::string GetOpSelectImplMode() const; + + inline void SetFusionSwitchConfigPath(const std::string &cfg_path); + inline std::string GetFusionSwitchConfigPath() const; + + // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize" + inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode); + inline std::string GetBufferOptimizeMode() const; + +private: + void SetInsertOpConfigPath(const std::vector &cfg_path); + std::vector GetInsertOpConfigPathChar() const; + + void SetInputFormat(const std::vector &format); + std::vector GetInputFormatChar() const; + + void SetInputShape(const std::vector &shape); + std::vector GetInputShapeChar() const; + + std::vector GetDynamicBatchSizeChar() const; + + void SetDynamicImageSize(const std::vector &dynamic_image_size); + std::vector GetDynamicImageSizeChar() const; + + void SetPrecisionMode(const std::vector &precision_mode); + std::vector GetPrecisionModeChar() const; + + void SetOpSelectImplMode(const std::vector &op_select_impl_mode); + std::vector GetOpSelectImplModeChar() const; + + void SetFusionSwitchConfigPath(const std::vector &cfg_path); + std::vector GetFusionSwitchConfigPathChar() const; + + void SetBufferOptimizeMode(const std::vector &buffer_optimize_mode); + std::vector GetBufferOptimizeModeChar() const; }; -class MS_API Ascend310DeviceInfo : public DeviceInfoContext { - public: - enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }; +using Ascend310DeviceInfo = AscendDeviceInfo; +using Ascend910DeviceInfo = AscendDeviceInfo; - void SetDeviceID(uint32_t device_id); - uint32_t GetDeviceID() const; - - inline void SetDumpConfigPath(const std::string &cfg_path); - inline std::string GetDumpConfigPath() const; - - // aipp config file - inline void SetInsertOpConfigPath(const std::string &cfg_path); - inline std::string GetInsertOpConfigPath() const; - - // nchw or nhwc - inline void SetInputFormat(const std::string 
&format); - inline std::string GetInputFormat() const; - - // Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1" - inline void SetInputShape(const std::string &shape); - inline std::string GetInputShape() const; - - void SetInputShapeMap(const std::map> &shape); - std::map> GetInputShapeMap() const; - - void SetDynamicBatchSize(const std::vector &dynamic_batch_size); - inline std::string GetDynamicBatchSize() const; - - // FP32, UINT8 or FP16, default as FP32 - void SetOutputType(enum DataType output_type); - enum DataType GetOutputType() const; - - // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", default as "force_fp16" - inline void SetPrecisionMode(const std::string &precision_mode); - inline std::string GetPrecisionMode() const; - - // Optional "high_performance" and "high_precision", "high_performance" is set as default - inline void SetOpSelectImplMode(const std::string &op_select_impl_mode); - inline std::string GetOpSelectImplMode() const; - - inline void SetFusionSwitchConfigPath(const std::string &cfg_path); - inline std::string GetFusionSwitchConfigPath() const; - - // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize" - inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode); - inline std::string GetBufferOptimizeMode() const; - - private: - void SetDumpConfigPath(const std::vector &cfg_path); - std::vector GetDumpConfigPathChar() const; - - void SetInsertOpConfigPath(const std::vector &cfg_path); - std::vector GetInsertOpConfigPathChar() const; - - void SetInputFormat(const std::vector &format); - std::vector GetInputFormatChar() const; - - void SetInputShape(const std::vector &shape); - std::vector GetInputShapeChar() const; - - std::vector GetDynamicBatchSizeChar() const; - - void SetPrecisionMode(const std::vector &precision_mode); - std::vector GetPrecisionModeChar() const; - - void SetOpSelectImplMode(const std::vector 
&op_select_impl_mode); - std::vector GetOpSelectImplModeChar() const; - - void SetFusionSwitchConfigPath(const std::vector &cfg_path); - std::vector GetFusionSwitchConfigPathChar() const; - - void SetBufferOptimizeMode(const std::vector &buffer_optimize_mode); - std::vector GetBufferOptimizeModeChar() const; -}; - -void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); } -std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); } - -void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { - SetInsertOpConfigPath(StringToChar(cfg_path)); +void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { + SetInsertOpConfigPath(StringToChar(cfg_path)); } -std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } +std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } -void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } -std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } +void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } +std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } -void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } -std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } +void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } +std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } -std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } 
+std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } -void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) { - SetPrecisionMode(StringToChar(precision_mode)); -} -std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } - -void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { - SetOpSelectImplMode(StringToChar(op_select_impl_mode)); -} -std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } - -void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { - SetFusionSwitchConfigPath(StringToChar(cfg_path)); -} -std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const { - return CharToString(GetFusionSwitchConfigPathChar()); +void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { + SetDynamicImageSize(StringToChar(dynamic_image_size)); } -void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { - SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); +std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } + +void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) { + SetPrecisionMode(StringToChar(precision_mode)); } -std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } +std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } + +void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { + SetOpSelectImplMode(StringToChar(op_select_impl_mode)); +} +std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } + +void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { + 
SetFusionSwitchConfigPath(StringToChar(cfg_path)); +} +std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const { + return CharToString(GetFusionSwitchConfigPathChar()); +} + +void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { + SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); +} +std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CONTEXT_H diff --git a/mindspore/lite/src/cxx_api/context.cc b/mindspore/lite/src/cxx_api/context.cc index 7e1ba01086e..9374c1fcb6d 100644 --- a/mindspore/lite/src/cxx_api/context.cc +++ b/mindspore/lite/src/cxx_api/context.cc @@ -317,13 +317,7 @@ std::vector GPUDeviceInfo::GetPrecisionModeChar() const { return ret; } -void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; } -uint32_t Ascend910DeviceInfo::GetDeviceID() const { - MS_LOG(ERROR) << "Unsupported Feature."; - return 0; -} - -void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { +void AscendDeviceInfo::SetDeviceID(uint32_t device_id) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -331,7 +325,7 @@ void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { data_->params[kModelOptionAscend310DeviceID] = device_id; } -uint32_t Ascend310DeviceInfo::GetDeviceID() const { +uint32_t AscendDeviceInfo::GetDeviceID() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return 0; @@ -339,14 +333,14 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const { return GetValue(data_, kModelOptionAscend310DeviceID); } -void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector &cfg_path) { +void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector &cfg_path) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; } data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); } -std::vector 
Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { +std::vector AscendDeviceInfo::GetInsertOpConfigPathChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -355,7 +349,7 @@ std::vector Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetInputFormat(const std::vector &format) { +void AscendDeviceInfo::SetInputFormat(const std::vector &format) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -363,7 +357,7 @@ void Ascend310DeviceInfo::SetInputFormat(const std::vector &format) { data_->params[kModelOptionAscend310InputFormat] = CharToString(format); } -std::vector Ascend310DeviceInfo::GetInputFormatChar() const { +std::vector AscendDeviceInfo::GetInputFormatChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -372,14 +366,14 @@ std::vector Ascend310DeviceInfo::GetInputFormatChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetInputShape(const std::vector &shape) { +void AscendDeviceInfo::SetInputShape(const std::vector &shape) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; } data_->params[kModelOptionAscend310InputShape] = CharToString(shape); } -std::vector Ascend310DeviceInfo::GetInputShapeChar() const { +std::vector AscendDeviceInfo::GetInputShapeChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -388,7 +382,7 @@ std::vector Ascend310DeviceInfo::GetInputShapeChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector &dynamic_batch_size) { +void AscendDeviceInfo::SetDynamicBatchSize(const std::vector &dynamic_batch_size) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -403,7 +397,7 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector &dynamic data_->params[kModelOptionAscend310DynamicBatchSize] = 
batchs; } -std::vector Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { +std::vector AscendDeviceInfo::GetDynamicBatchSizeChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -412,7 +406,7 @@ std::vector Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector &dynamic_image_size) { +void AscendDeviceInfo::SetDynamicImageSize(const std::vector &dynamic_image_size) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -420,7 +414,7 @@ void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector &dynamic_i data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size); } -std::vector Ascend310DeviceInfo::GetDynamicImageSizeChar() const { +std::vector AscendDeviceInfo::GetDynamicImageSizeChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -429,7 +423,7 @@ std::vector Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetPrecisionMode(const std::vector &precision_mode) { +void AscendDeviceInfo::SetPrecisionMode(const std::vector &precision_mode) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -437,7 +431,7 @@ void Ascend310DeviceInfo::SetPrecisionMode(const std::vector &precision_mo data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode); } -std::vector Ascend310DeviceInfo::GetPrecisionModeChar() const { +std::vector AscendDeviceInfo::GetPrecisionModeChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -446,7 +440,7 @@ std::vector Ascend310DeviceInfo::GetPrecisionModeChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector &op_select_impl_mode) { +void AscendDeviceInfo::SetOpSelectImplMode(const std::vector 
&op_select_impl_mode) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -454,7 +448,7 @@ void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector &op_select data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode); } -std::vector Ascend310DeviceInfo::GetOpSelectImplModeChar() const { +std::vector AscendDeviceInfo::GetOpSelectImplModeChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -463,14 +457,14 @@ std::vector Ascend310DeviceInfo::GetOpSelectImplModeChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector &cfg_path) { +void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector &cfg_path) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; } data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path); } -std::vector Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { +std::vector AscendDeviceInfo::GetFusionSwitchConfigPathChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); @@ -479,7 +473,7 @@ std::vector Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { return StringToChar(ref); } -void Ascend310DeviceInfo::SetInputShapeMap(const std::map> &shape) { +void AscendDeviceInfo::SetInputShapeMap(const std::map> &shape) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -487,7 +481,7 @@ void Ascend310DeviceInfo::SetInputShapeMap(const std::map> data_->params[kModelOptionAscend310InputShapeMap] = shape; } -std::map> Ascend310DeviceInfo::GetInputShapeMap() const { +std::map> AscendDeviceInfo::GetInputShapeMap() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::map>(); @@ -495,7 +489,7 @@ std::map> Ascend310DeviceInfo::GetInputShapeMap() const { return GetValue>>(data_, kModelOptionAscend310InputShapeMap); } -void 
Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { +void AscendDeviceInfo::SetOutputType(enum DataType output_type) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -503,7 +497,7 @@ void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { data_->params[kModelOptionAscend310OutputType] = output_type; } -enum DataType Ascend310DeviceInfo::GetOutputType() const { +enum DataType AscendDeviceInfo::GetOutputType() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return DataType::kTypeUnknown; @@ -511,7 +505,7 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const { return GetValue(data_, kModelOptionAscend310OutputType); } -void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector &buffer_optimize_mode) { +void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector &buffer_optimize_mode) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return; @@ -519,7 +513,7 @@ void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector &buffer_ data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode); } -std::vector Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { +std::vector AscendDeviceInfo::GetBufferOptimizeModeChar() const { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; return std::vector(); diff --git a/mindspore/lite/src/cxx_api/converters.cc b/mindspore/lite/src/cxx_api/converters.cc index 332d9bfd8a8..097d8771b02 100644 --- a/mindspore/lite/src/cxx_api/converters.cc +++ b/mindspore/lite/src/cxx_api/converters.cc @@ -104,7 +104,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) { } else if (device->GetDeviceType() == kKirinNPU) { auto npu_context = device->Cast(); ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get()); - } else if (device->GetDeviceType() == kAscend310) { + } else if (device->GetDeviceType() == kAscend) { ret = AddAscend310Device(inner_context.get(), device.get()); } if (ret 
!= kSuccess) { diff --git a/tests/st/cpp/common/common_test.cc b/tests/st/cpp/common/common_test.cc index 0321cc0f781..2e280780a26 100644 --- a/tests/st/cpp/common/common_test.cc +++ b/tests/st/cpp/common/common_test.cc @@ -71,11 +71,11 @@ std::shared_ptr Common::ContextAutoSet() { auto context = std::make_shared(); if (device_target_str == "Ascend310") { - auto ascend310_info = std::make_shared(); + auto ascend310_info = std::make_shared(); ascend310_info->SetDeviceID(device_id); context->MutableDeviceInfo().emplace_back(ascend310_info); } else if (device_target_str == "Ascend910") { - auto ascend310_info = std::make_shared(); + auto ascend310_info = std::make_shared(); ascend310_info->SetDeviceID(device_id); context->MutableDeviceInfo().emplace_back(ascend310_info); } else { diff --git a/tests/st/cpp/dataset/test_de.cc b/tests/st/cpp/dataset/test_de.cc index f55e92edadc..fac9295df00 100644 --- a/tests/st/cpp/dataset/test_de.cc +++ b/tests/st/cpp/dataset/test_de.cc @@ -101,7 +101,7 @@ TEST_F(TestDE, TestDvpp) { auto context = ContextAutoSet(); ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); - auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); + auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ascend310_info != nullptr); auto device_id = ascend310_info->GetDeviceID(); @@ -154,7 +154,7 @@ TEST_F(TestDE, TestDvppSinkMode) { auto context = ContextAutoSet(); ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); - auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); + auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ascend310_info != nullptr); auto device_id = ascend310_info->GetDeviceID(); @@ -202,7 +202,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) { auto context = ContextAutoSet(); ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); - auto ascend310_info = 
context->MutableDeviceInfo()[0]->Cast(); + auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ascend310_info != nullptr); auto device_id = ascend310_info->GetDeviceID(); diff --git a/tests/st/cpp/model/test_dynamic_batch_size.cc b/tests/st/cpp/model/test_dynamic_batch_size.cc index 8f79fa18382..7ccca332bfb 100644 --- a/tests/st/cpp/model/test_dynamic_batch_size.cc +++ b/tests/st/cpp/model/test_dynamic_batch_size.cc @@ -38,7 +38,7 @@ TEST_F(TestDynamicBatchSize, InferMindIR) { auto context = ContextAutoSet(); ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); - auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); + auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ascend310_info != nullptr); std::map> input_shape; diff --git a/tests/st/cpp/model/test_zero_copy.cc b/tests/st/cpp/model/test_zero_copy.cc index 4b68ee86192..6982f569027 100644 --- a/tests/st/cpp/model/test_zero_copy.cc +++ b/tests/st/cpp/model/test_zero_copy.cc @@ -59,7 +59,7 @@ TEST_F(TestZeroCopy, TestMindIR) { auto context = ContextAutoSet(); ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); - auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); + auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ascend310_info != nullptr); ascend310_info->SetInsertOpConfigPath(aipp_path); auto device_id = ascend310_info->GetDeviceID(); @@ -107,7 +107,7 @@ TEST_F(TestZeroCopy, TestDeviceTensor) { auto context = ContextAutoSet(); ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); - auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); + auto ascend310_info = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ascend310_info != nullptr); ascend310_info->SetInsertOpConfigPath(aipp_path); auto device_id = ascend310_info->GetDeviceID(); diff --git a/tests/ut/cpp/cxx_api/context_test.cc 
b/tests/ut/cpp/cxx_api/context_test.cc index 288cc6eef59..7e50a2f4589 100644 --- a/tests/ut/cpp/cxx_api/context_test.cc +++ b/tests/ut/cpp/cxx_api/context_test.cc @@ -27,32 +27,27 @@ TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) { std::shared_ptr cpu = std::make_shared(); std::shared_ptr gpu = std::make_shared(); std::shared_ptr kirin_npu = std::make_shared(); - std::shared_ptr ascend310 = std::make_shared(); - std::shared_ptr ascend910 = std::make_shared(); + std::shared_ptr ascend = std::make_shared(); ASSERT_TRUE(cpu->Cast() != nullptr); ASSERT_TRUE(gpu->Cast() != nullptr); ASSERT_TRUE(kirin_npu->Cast() != nullptr); - ASSERT_TRUE(ascend310->Cast() != nullptr); - ASSERT_TRUE(ascend910->Cast() != nullptr); + ASSERT_TRUE(ascend->Cast() != nullptr); } TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) { std::shared_ptr cpu = std::make_shared(); std::shared_ptr gpu = std::make_shared(); std::shared_ptr kirin_npu = std::make_shared(); - std::shared_ptr ascend310 = std::make_shared(); - std::shared_ptr ascend910 = std::make_shared(); + std::shared_ptr ascend = std::make_shared(); ASSERT_TRUE(cpu->Cast() == nullptr); ASSERT_TRUE(kirin_npu->Cast() == nullptr); - ASSERT_TRUE(ascend310->Cast() == nullptr); - ASSERT_TRUE(ascend910->Cast() == nullptr); + ASSERT_TRUE(ascend->Cast() == nullptr); ASSERT_TRUE(gpu->Cast() == nullptr); ASSERT_TRUE(kirin_npu->Cast() == nullptr); - ASSERT_TRUE(ascend310->Cast() == nullptr); - ASSERT_TRUE(ascend910->Cast() == nullptr); + ASSERT_TRUE(ascend->Cast() == nullptr); } TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) { @@ -86,7 +81,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) { std::string option_9_ans = "1,2,3,4,5"; auto context = std::make_shared(); - std::shared_ptr ascend310 = std::make_shared(); + std::shared_ptr ascend310 = std::make_shared(); ascend310->SetInputShape(option_1); ascend310->SetInsertOpConfigPath(option_2); ascend310->SetOpSelectImplMode(option_3); @@ -99,7 
+94,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) { context->MutableDeviceInfo().push_back(ascend310); ASSERT_EQ(context->MutableDeviceInfo().size(), 1); - auto ctx = context->MutableDeviceInfo()[0]->Cast(); + auto ctx = context->MutableDeviceInfo()[0]->Cast(); ASSERT_TRUE(ctx != nullptr); ASSERT_EQ(ascend310->GetInputShape(), option_1); ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2); @@ -113,7 +108,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) { } TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) { - auto ctx = std::make_shared(); + auto ctx = std::make_shared(); ASSERT_EQ(ctx->GetOpSelectImplMode(), ""); } } // namespace mindspore