C++ inference interface Ascend310DeviceInfo+Ascend910DeviceInfo->AscendDeviceInfo

This commit is contained in:
xuyongfei 2021-12-09 15:20:31 +08:00
parent 87674cf3bd
commit 0a0419498f
28 changed files with 632 additions and 374 deletions

View File

@ -28,6 +28,7 @@ enum DeviceType {
kCPU = 0, kCPU = 0,
kGPU, kGPU,
kKirinNPU, kKirinNPU,
kAscend,
kAscend910, kAscend910,
kAscend310, kAscend310,
// add new type here // add new type here
@ -287,34 +288,14 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
} }
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
/// invalid for MindSpore Lite. /// invalid for MindSpore Lite.
class MS_API Ascend310DeviceInfo : public DeviceInfoContext { class MS_API AscendDeviceInfo : public DeviceInfoContext {
public: public:
/// \brief Get the type of this DeviceInfoContext. /// \brief Get the type of this DeviceInfoContext.
/// ///
/// \return Type of this DeviceInfoContext. /// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }; enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
/// \brief Set device id. /// \brief Set device id.
/// ///
@ -447,45 +428,48 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
std::vector<char> GetBufferOptimizeModeChar() const; std::vector<char> GetBufferOptimizeModeChar() const;
}; };
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) { using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;
void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path)); SetInsertOpConfigPath(StringToChar(cfg_path));
} }
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void Ascend310DeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) { void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
SetDynamicImageSize(StringToChar(dynamic_image_size)); SetDynamicImageSize(StringToChar(dynamic_image_size));
} }
std::string Ascend310DeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); } std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) { void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode)); SetPrecisionMode(StringToChar(precision_mode));
} }
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); } std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) { void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode)); SetOpSelectImplMode(StringToChar(op_select_impl_mode));
} }
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); } std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) { void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path)); SetFusionSwitchConfigPath(StringToChar(cfg_path));
} }
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const { std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar()); return CharToString(GetFusionSwitchConfigPathChar());
} }
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
} }
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H #endif // MINDSPORE_INCLUDE_API_CONTEXT_H

View File

@ -193,7 +193,7 @@ class MS_API Model {
/// \brief Inference model. /// \brief Inference model.
/// ///
/// \param[in] device_type Device typeoptions are kGPU, kAscend910, etc. /// \param[in] device_type Device typeoptions are kGPU, kAscend, kAscend910, etc.
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM. /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
/// ///
/// \return Is supported or not. /// \return Is supported or not.

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
#include <string>
#include "acl/acl_base.h"
namespace mindspore {
static inline bool IsAscend910Soc() {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") == std::string::npos) {
return false;
}
return true;
}
static inline bool IsAscendNo910Soc() {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
return false;
}
return true;
}
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H

View File

@ -175,55 +175,46 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend910DeviceID] = device_id;
}
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID);
}
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310DeviceID] = device_id; data_->params[kModelOptionAscend310DeviceID] = device_id;
} }
uint32_t Ascend310DeviceInfo::GetDeviceID() const { uint32_t AscendDeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID); return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
} }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputFormat] = CharToString(format); data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
} }
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const { std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShape] = CharToString(shape); data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
} }
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const { std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
std::string batchs = ""; std::string batchs = "";
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) { for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
@ -234,69 +225,69 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
} }
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs; data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
} }
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; } void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }
std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); } std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode); data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const { std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode); data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const { std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path); data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath); const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShapeMap] = shape; data_->params[kModelOptionAscend310InputShapeMap] = shape;
} }
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const { std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap); return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
} }
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OutputType] = output_type; data_->params[kModelOptionAscend310OutputType] = output_type;
} }
enum DataType Ascend310DeviceInfo::GetOutputType() const { enum DataType AscendDeviceInfo::GetOutputType() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType); return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
} }
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode); data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
return StringToChar(ref); return StringToChar(ref);

View File

@ -24,7 +24,31 @@
#include "utils/utils.h" #include "utils/utils.h"
namespace mindspore { namespace mindspore {
inline std::string g_device_target = "Default"; inline enum DeviceType g_device_target = kInvalidDeviceType;
static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
switch (device_type) {
case kAscend:
stream << "Ascend";
break;
case kAscend910:
stream << "Ascend910";
break;
case kAscend310:
stream << "Ascend310";
break;
case kGPU:
stream << "GPU";
break;
case kCPU:
stream << "CPU";
break;
default:
stream << "[InvalidDeviceType: " << static_cast<int>(device_type) << "]";
break;
}
return stream;
}
template <class T> template <class T>
class Factory { class Factory {
@ -39,32 +63,24 @@ class Factory {
return instance; return instance;
} }
void Register(const std::string &device_name, U &&creator) { void Register(U &&creator) { creators_.push_back(creator); }
if (creators_.find(device_name) == creators_.end()) {
(void)creators_.emplace(device_name, creator); std::shared_ptr<T> Create(enum DeviceType device_type) {
for (auto &item : creators_) {
MS_EXCEPTION_IF_NULL(item);
auto val = item();
if (val->CheckDeviceSupport(device_type)) {
return val;
}
} }
} MS_LOG(WARNING) << "Unsupported device target " << device_type;
bool CheckModelSupport(const std::string &device_name) {
return std::any_of(creators_.begin(), creators_.end(),
[&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; });
}
std::shared_ptr<T> Create(const std::string &device_name) {
auto iter = creators_.find(device_name);
if (creators_.end() != iter) {
MS_EXCEPTION_IF_NULL(iter->second);
return (iter->second)();
}
MS_LOG(ERROR) << "Unsupported device target " << device_name;
return nullptr; return nullptr;
} }
private: private:
Factory() = default; Factory() = default;
~Factory() = default; ~Factory() = default;
std::map<std::string, U> creators_; std::vector<U> creators_;
}; };
template <class T> template <class T>
@ -72,14 +88,12 @@ class Registrar {
using U = std::function<std::shared_ptr<T>()>; using U = std::function<std::shared_ptr<T>()>;
public: public:
Registrar(const std::string &device_name, U creator) { explicit Registrar(U creator) { Factory<T>::Instance().Register(std::move(creator)); }
Factory<T>::Instance().Register(device_name, std::move(creator));
}
~Registrar() = default; ~Registrar() = default;
}; };
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \ #define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \ static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_reg( \
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); }); []() { return std::make_shared<DERIVE_CLASS>(); });
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H #endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H

View File

@ -18,9 +18,10 @@
#include "cxx_api/model/acl/model_converter.h" #include "cxx_api/model/acl/model_converter.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h" #include "mindspore/core/utils/convert_utils_base.h"
#include "cxx_api/acl_utils.h"
namespace mindspore { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, AclGraphImpl);
AclGraphImpl::AclGraphImpl() AclGraphImpl::AclGraphImpl()
: init_flag_(false), : init_flag_(false),
@ -231,4 +232,12 @@ Status AclGraphImpl::ConvertToOM() {
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
return kMCFailed; return kMCFailed;
} }
bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend310
if (device_type != kAscend && device_type != kAscend310) {
return false;
}
return IsAscendNo910Soc();
}
} // namespace mindspore } // namespace mindspore

View File

@ -37,6 +37,7 @@ class AclGraphImpl : public GraphCell::GraphImpl {
Status Load(uint32_t device_id) override; Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override; std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override; std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
private: private:
Status ConvertToOM(); Status ConvertToOM();

View File

@ -18,6 +18,7 @@
#include "include/api/context.h" #include "include/api/context.h"
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
#include "cxx_api/akg_kernel_register.h" #include "cxx_api/akg_kernel_register.h"
#include "cxx_api/acl_utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/context/context_extends.h" #include "utils/context/context_extends.h"
#include "mindspore/core/base/base_ref_utils.h" #include "mindspore/core/base/base_ref_utils.h"
@ -30,7 +31,7 @@
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
namespace mindspore { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl);
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL"; static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE"; static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
@ -382,6 +383,14 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv
return acl_env; return acl_env;
} }
bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend910
if (device_type != kAscend && device_type != kAscend910) {
return false;
}
return IsAscend910Soc();
}
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_; std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;

View File

@ -39,6 +39,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl {
Status Load(uint32_t device_id) override; Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override; std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override; std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
private: private:
class MsEnvGuard; class MsEnvGuard;

View File

@ -26,7 +26,7 @@
#include "runtime/device/gpu/cuda_driver.h" #include "runtime/device/gpu/cuda_driver.h"
namespace mindspore { namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl); API_FACTORY_REG(GraphCell::GraphImpl, GPUGraphImpl);
GPUGraphImpl::GPUGraphImpl() GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr), : session_impl_(nullptr),
@ -291,4 +291,6 @@ std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
} }
return result; return result;
} }
bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; }
} // namespace mindspore } // namespace mindspore

View File

@ -37,6 +37,8 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
std::vector<MSTensor> GetInputs() override; std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override; std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
private: private:
Status InitEnv(); Status InitEnv();
Status FinalizeEnv(); Status FinalizeEnv();

View File

@ -42,6 +42,8 @@ class GraphCell::GraphImpl {
virtual std::vector<MSTensor> GetInputs() = 0; virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0; virtual std::vector<MSTensor> GetOutputs() = 0;
virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
protected: protected:
std::shared_ptr<Graph> graph_; std::shared_ptr<Graph> graph_;
std::shared_ptr<Context> graph_context_; std::shared_ptr<Context> graph_context_;

View File

@ -21,7 +21,7 @@
#include "include/api/context.h" #include "include/api/context.h"
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
#include "cxx_api/graph/acl/acl_env_guard.h" #include "cxx_api/graph/acl/acl_env_guard.h"
#include "acl/acl_base.h" #include "cxx_api/acl_utils.h"
namespace mindspore { namespace mindspore {
Status AclModel::Build() { Status AclModel::Build() {
@ -112,7 +112,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
if (model_context_ == nullptr) { if (model_context_ == nullptr) {
model_context_ = std::make_shared<Context>(); model_context_ = std::make_shared<Context>();
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>()); model_context_->MutableDeviceInfo().emplace_back(std::make_shared<AscendDeviceInfo>());
} }
std::string input_shape_option; std::string input_shape_option;
@ -139,7 +139,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
MS_LOG(ERROR) << "Invalid model context, only single device info is supported."; MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
return kMCInvalidArgs; return kMCInvalidArgs;
} }
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
MS_EXCEPTION_IF_NULL(ascend310_info); MS_EXCEPTION_IF_NULL(ascend310_info);
ascend310_info->SetInputShape(input_shape_option); ascend310_info->SetInputShape(input_shape_option);
auto graph_cell_bak = std::move(graph_cell_); auto graph_cell_bak = std::move(graph_cell_);
@ -163,16 +163,15 @@ std::vector<MSTensor> AclModel::GetOutputs() {
return graph_cell_->GetOutputs(); return graph_cell_->GetOutputs();
} }
bool AclModel::CheckModelSupport(enum ModelType model_type) { bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) {
const char *soc_name_c = aclrtGetSocName(); // for Ascend, only support kAscend and kAscend310
if (soc_name_c == nullptr) { if (device_type != kAscend && device_type != kAscend310) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
return false; return false;
} }
return IsAscendNo910Soc();
}
bool AclModel::CheckModelSupport(enum ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM}; static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM};
auto iter = kSupportedModelMap.find(model_type); auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) { if (iter == kSupportedModelMap.end()) {

View File

@ -43,6 +43,7 @@ class AclModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override; std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override; std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
bool CheckModelSupport(enum ModelType model_type) override; bool CheckModelSupport(enum ModelType model_type) override;
private: private:

View File

@ -29,7 +29,7 @@
#include "cxx_api/model/acl/acl_vm/acl_vm.h" #include "cxx_api/model/acl/acl_vm/acl_vm.h"
namespace mindspore { namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti); API_FACTORY_REG(ModelImpl, AclModelMulti);
namespace { namespace {
std::map<DataType, size_t> kDtypeMap = { std::map<DataType, size_t> kDtypeMap = {

View File

@ -33,7 +33,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
if (device_infos.size() != 1) { if (device_infos.size() != 1) {
return; return;
} }
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
if (ascend310_info == nullptr) { if (ascend310_info == nullptr) {
return; return;
} }

View File

@ -20,19 +20,6 @@
#include "utils/utils.h" #include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace {
std::string GetDeviceTypeString(enum DeviceType type) {
static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
{kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
};
auto iter = kDeviceTypeStrs.find(type);
if (iter != kDeviceTypeStrs.end()) {
return iter->second;
}
return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
}
} // namespace
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context, Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
const std::shared_ptr<TrainCfg> &) { const std::shared_ptr<TrainCfg> &) {
if (graph_cell.GetGraph() == nullptr) { if (graph_cell.GetGraph() == nullptr) {
@ -50,7 +37,7 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
return kMCInvalidInput; return kMCInvalidInput;
} }
std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType()); auto device_target = device_info[0]->GetDeviceType();
impl_ = Factory<ModelImpl>::Instance().Create(device_target); impl_ = Factory<ModelImpl>::Instance().Create(device_target);
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Create session type " << device_target << " failed"; MS_LOG(ERROR) << "Create session type " << device_target << " failed";
@ -175,16 +162,10 @@ Model::Model() : impl_(nullptr) {}
Model::~Model() {} Model::~Model() {}
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) { bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
std::string device_type_str = GetDeviceTypeString(device_type); auto check_model = Factory<ModelImpl>::Instance().Create(device_type);
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
return false;
}
auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
if (check_model == nullptr) { if (check_model == nullptr) {
return false; return false;
} }
return check_model->CheckModelSupport(model_type); return check_model->CheckModelSupport(model_type);
} }

View File

@ -44,7 +44,8 @@ class ModelImpl {
virtual std::vector<MSTensor> GetInputs() = 0; virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0; virtual std::vector<MSTensor> GetOutputs() = 0;
virtual bool CheckModelSupport(enum ModelType model_type) { return false; } virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
virtual bool CheckModelSupport(enum ModelType model_type) = 0;
virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs); virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

View File

@ -20,14 +20,13 @@
#include "include/api/context.h" #include "include/api/context.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "cxx_api/factory.h" #include "cxx_api/factory.h"
#if ENABLE_D
#include "cxx_api/acl_utils.h"
#endif
namespace mindspore { namespace mindspore {
// mindspore-serving check current package for version check with ModelImpl factory. // mindspore-serving check current package for version check with ModelImpl factory.
#if ENABLE_D API_FACTORY_REG(ModelImpl, MsModel);
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
#elif ENABLE_GPU
API_FACTORY_REG(ModelImpl, GPU, MsModel);
#endif
static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) { static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key; std::string shape_key;
@ -171,18 +170,23 @@ uint32_t MsModel::GetDeviceID() const {
return 0; return 0;
} }
bool MsModel::CheckModelSupport(enum ModelType model_type) { bool MsModel::CheckDeviceSupport(enum DeviceType device_type) {
#if ENABLE_D #if ENABLE_D
const char *soc_name_c = aclrtGetSocName(); // for Ascend, only support kAscend or kAscend910
if (soc_name_c == nullptr) { if (device_type != kAscend && device_type != kAscend910) {
return false; return false;
} }
std::string soc_name(soc_name_c); return IsAscend910Soc();
if (soc_name.find("910") == std::string::npos) { #else
// otherwise, only support GPU
if (device_type != kGPU) {
return false; return false;
} }
return true;
#endif #endif
}
bool MsModel::CheckModelSupport(mindspore::ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR}; static const std::set<ModelType> kSupportedModelMap = {kMindIR};
auto iter = kSupportedModelMap.find(model_type); auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) { if (iter == kSupportedModelMap.end()) {

View File

@ -44,6 +44,7 @@ class MsModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override; std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override; std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
bool CheckModelSupport(enum ModelType model_type) override; bool CheckModelSupport(enum ModelType model_type) override;
private: private:

View File

@ -1,18 +1,18 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H #define MINDSPORE_INCLUDE_API_CONTEXT_H
@ -25,212 +25,431 @@
namespace mindspore { namespace mindspore {
enum DeviceType { enum DeviceType {
kCPU = 0, kCPU = 0,
kGPU, kGPU,
kKirinNPU, kKirinNPU,
kAscend910, kAscend,
kAscend310, kAscend910,
// add new type here kAscend310,
kInvalidDeviceType = 100, // add new type here
kInvalidDeviceType = 100,
}; };
class Allocator; class Allocator;
class Delegate;
class DeviceInfoContext; class DeviceInfoContext;
/// \brief Context is used to store environment variables during execution.
class MS_API Context { class MS_API Context {
public: public:
Context(); struct Data;
~Context() = default; Context();
~Context() = default;
void SetThreadNum(int32_t thread_num); /// \brief Set the number of threads at runtime. Only valid for Lite.
int32_t GetThreadNum() const; ///
/// \param[in] thread_num the number of threads at runtime.
void SetThreadNum(int32_t thread_num);
void SetAllocator(const std::shared_ptr<Allocator> &allocator); /// \brief Get the current thread number setting. Only valid for Lite.
std::shared_ptr<Allocator> GetAllocator() const; ///
/// \return The current thread number setting.
int32_t GetThreadNum() const;
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo(); /// \brief Set the thread affinity to CPU cores. Only valid for Lite.
///
/// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
private: /// \brief Get the thread affinity of CPU cores. Only valid for Lite.
struct Data; ///
std::shared_ptr<Data> data_; /// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
int GetThreadAffinityMode() const;
/// \brief Set the thread lists to CPU cores. Only valid for Lite.
///
/// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the
/// mode is not effective.
///
/// \param[in] core_list: a vector of thread core lists.
void SetThreadAffinity(const std::vector<int> &core_list);
/// \brief Get the thread lists of CPU cores. Only valid for Lite.
///
/// \return core_list: a vector of thread core lists.
std::vector<int32_t> GetThreadAffinityCoreList() const;
/// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite.
///
/// \param[in] is_parallel: true, parallel; false, not in parallel.
void SetEnableParallel(bool is_parallel);
/// \brief Get the status whether to perform model inference or training in parallel. Only valid for Lite.
///
/// \return Bool value that indicates whether in parallel.
bool GetEnableParallel() const;
/// \brief Set Delegate to access third-party AI framework. Only valid for Lite.
///
/// \param[in] Pointer to the custom delegate.
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
/// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
///
/// \return Pointer to the custom delegate.
std::shared_ptr<Delegate> GetDelegate() const;
/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
/// heterogeneous scenarios with multiple members in the vector.
///
/// \return Mutable reference of DeviceInfoContext vector in this context.
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
private:
std::shared_ptr<Data> data_;
}; };
/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> { class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
public: public:
struct Data; struct Data;
DeviceInfoContext(); DeviceInfoContext();
virtual ~DeviceInfoContext() = default; virtual ~DeviceInfoContext() = default;
virtual enum DeviceType GetDeviceType() const = 0;
template <class T> /// \brief Get the type of this DeviceInfoContext.
std::shared_ptr<T> Cast() { ///
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type."); /// \return Type of this DeviceInfoContext.
if (GetDeviceType() != T().GetDeviceType()) { virtual enum DeviceType GetDeviceType() const = 0;
return nullptr;
}
return std::static_pointer_cast<T>(shared_from_this()); /// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
} /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
///
/// \param T Type
/// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
template <class T>
std::shared_ptr<T> Cast() {
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
if (GetDeviceType() != T().GetDeviceType()) {
return nullptr;
}
protected: return std::static_pointer_cast<T>(shared_from_this());
std::shared_ptr<Data> data_; }
/// \brief obtain provider's name
///
/// \return provider's name.
std::string GetProvider() const;
/// \brief set provider's name.
///
/// \param[in] provider define the provider's name.
void SetProvider(const std::string &provider);
/// \brief obtain provider's device type.
///
/// \return provider's device type.
std::string GetProviderDevice() const;
/// \brief set provider's device type.
///
/// \param[in] device define the provider's device type.EG: CPU.
void SetProviderDevice(const std::string &device);
/// \brief set memory allocator.
///
/// \param[in] allocator define the memory allocator which can be defined by user.
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
/// \brief obtain memory allocator.
///
/// \return memory allocator.
std::shared_ptr<Allocator> GetAllocator() const;
protected:
std::shared_ptr<Data> data_;
}; };
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
/// for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext { class MS_API CPUDeviceInfo : public DeviceInfoContext {
public: public:
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }; /// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
/// \brief Set the thread affinity to CPU cores. /// \brief Set enables to perform the float16 inference
/// ///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first /// \param[in] is_fp16 Enable float16 inference or not.
void SetThreadAffinity(int mode); void SetEnableFP16(bool is_fp16);
int GetThreadAffinity() const;
void SetEnableFP16(bool is_fp16); /// \brief Get enables to perform the float16 inference
bool GetEnableFP16() const; ///
/// \return Whether enable float16 inference.
bool GetEnableFP16() const;
}; };
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
/// for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext { class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
public: public:
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }; /// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
void SetFrequency(int frequency); /// \brief Set the NPU frequency.
int GetFrequency() const; ///
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
/// performance), default as 3.
void SetFrequency(int frequency);
/// \brief Get the NPU frequency.
///
/// \return NPU frequency
int GetFrequency() const;
}; };
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext { class MS_API GPUDeviceInfo : public DeviceInfoContext {
public: public:
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }; /// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
void SetDeviceID(uint32_t device_id); /// \brief Set device id.
uint32_t GetDeviceID() const; ///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
void SetGpuTrtInferMode(bool gpu_trt_infer_mode); /// \brief Get the device id.
bool GetGpuTrtInferMode() const; ///
/// \return The device id.
uint32_t GetDeviceID() const;
void SetEnableFP16(bool is_fp16); /// \brief Get the distribution rank id.
bool GetEnableFP16() const; ///
/// \return The device id.
int GetRankID() const;
/// \brief Get the distribution group size.
///
/// \return The device id.
int GetGroupSize() const;
/// \brief Set the precision mode.
///
/// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
inline void SetPrecisionMode(const std::string &precision_mode);
/// \brief Get the precision mode.
///
/// \return The precision mode.
inline std::string GetPrecisionMode() const;
/// \brief Set enables to perform the float16 inference
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);
/// \brief Get enables to perform the float16 inference
///
/// \return Whether enable float16 inference.
bool GetEnableFP16() const;
private:
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
}; };
class MS_API Ascend910DeviceInfo : public DeviceInfoContext { void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
public: SetPrecisionMode(StringToChar(precision_mode));
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }; }
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void SetDeviceID(uint32_t device_id); /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
uint32_t GetDeviceID() const; /// invalid for MindSpore Lite.
class MS_API AscendDeviceInfo : public DeviceInfoContext {
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;
/// \brief Set AIPP configuration file path.
///
/// \param[in] cfg_path AIPP configuration file path.
inline void SetInsertOpConfigPath(const std::string &cfg_path);
/// \brief Get AIPP configuration file path.
///
/// \return AIPP configuration file path.
inline std::string GetInsertOpConfigPath() const;
/// \brief Set format of model inputs.
///
/// \param[in] format Optional "NCHW", "NHWC", etc.
inline void SetInputFormat(const std::string &format);
/// \brief Get format of model inputs.
///
/// \return The format of model inputs.
inline std::string GetInputFormat() const;
/// \brief Set shape of model inputs.
///
/// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
inline void SetInputShape(const std::string &shape);
/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
inline std::string GetInputShape() const;
/// \brief Set shape of model inputs.
///
/// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
/// shape 4,3,2,1.
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
std::map<int, std::vector<int>> GetInputShapeMap() const;
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;
/// \brief Set the dynamic image size of model inputs.
///
/// \param[in] image size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
inline void SetDynamicImageSize(const std::string &dynamic_image_size);
/// \brief Get dynamic image size of model inputs.
///
/// \return The image size of model inputs.
inline std::string GetDynamicImageSize() const;
/// \brief Set type of model outputs.
///
/// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
void SetOutputType(enum DataType output_type);
/// \brief Get type of model outputs.
///
/// \return The set type of model outputs.
enum DataType GetOutputType() const;
/// \brief Set precision mode of model.
///
/// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
/// "allow_mix_precision", "force_fp16" is set as default
inline void SetPrecisionMode(const std::string &precision_mode);
/// \brief Get precision mode of model.
///
/// \return The set type of model outputs
inline std::string GetPrecisionMode() const;
/// \brief Set op select implementation mode.
///
/// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
/// default.
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
/// \brief Get op select implementation mode.
///
/// \return The set op select implementation mode.
inline std::string GetOpSelectImplMode() const;
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;
private:
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;
void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;
void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;
std::vector<char> GetDynamicBatchSizeChar() const;
void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
std::vector<char> GetDynamicImageSizeChar() const;
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
}; };
class MS_API Ascend310DeviceInfo : public DeviceInfoContext { using Ascend310DeviceInfo = AscendDeviceInfo;
public: using Ascend910DeviceInfo = AscendDeviceInfo;
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
void SetDeviceID(uint32_t device_id); void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
uint32_t GetDeviceID() const; SetInsertOpConfigPath(StringToChar(cfg_path));
inline void SetDumpConfigPath(const std::string &cfg_path);
inline std::string GetDumpConfigPath() const;
// aipp config file
inline void SetInsertOpConfigPath(const std::string &cfg_path);
inline std::string GetInsertOpConfigPath() const;
// nchw or nhwc
inline void SetInputFormat(const std::string &format);
inline std::string GetInputFormat() const;
// Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
inline void SetInputShape(const std::string &shape);
inline std::string GetInputShape() const;
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
std::map<int, std::vector<int>> GetInputShapeMap() const;
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;
// FP32, UINT8 or FP16, default as FP32
void SetOutputType(enum DataType output_type);
enum DataType GetOutputType() const;
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", default as "force_fp16"
inline void SetPrecisionMode(const std::string &precision_mode);
inline std::string GetPrecisionMode() const;
// Optional "high_performance" and "high_precision", "high_performance" is set as default
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
inline std::string GetOpSelectImplMode() const;
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;
private:
void SetDumpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetDumpConfigPathChar() const;
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;
void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;
void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;
std::vector<char> GetDynamicBatchSizeChar() const;
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
};
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path));
} }
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); } std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); } void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); } std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); } void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); } std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); } std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) { void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
SetPrecisionMode(StringToChar(precision_mode)); SetDynamicImageSize(StringToChar(dynamic_image_size));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
} }
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) { std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
} }
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
}
void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H #endif // MINDSPORE_INCLUDE_API_CONTEXT_H

View File

@ -317,13 +317,7 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
return ret; return ret;
} }
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; } void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
MS_LOG(ERROR) << "Unsupported Feature.";
return 0;
}
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -331,7 +325,7 @@ void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
data_->params[kModelOptionAscend310DeviceID] = device_id; data_->params[kModelOptionAscend310DeviceID] = device_id;
} }
uint32_t Ascend310DeviceInfo::GetDeviceID() const { uint32_t AscendDeviceInfo::GetDeviceID() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return 0; return 0;
@ -339,14 +333,14 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID); return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
} }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
} }
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path); data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -355,7 +349,7 @@ std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -363,7 +357,7 @@ void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
data_->params[kModelOptionAscend310InputFormat] = CharToString(format); data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
} }
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const { std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -372,14 +366,14 @@ std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
} }
data_->params[kModelOptionAscend310InputShape] = CharToString(shape); data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
} }
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const { std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -388,7 +382,7 @@ std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -403,7 +397,7 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs; data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
} }
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -412,7 +406,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -420,7 +414,7 @@ void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_i
data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size); data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size);
} }
std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -429,7 +423,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -437,7 +431,7 @@ void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mo
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode); data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const { std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -446,7 +440,7 @@ std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -454,7 +448,7 @@ void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode); data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const { std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -463,14 +457,14 @@ std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
} }
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path); data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();
@ -479,7 +473,7 @@ std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
return StringToChar(ref); return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -487,7 +481,7 @@ void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>>
data_->params[kModelOptionAscend310InputShapeMap] = shape; data_->params[kModelOptionAscend310InputShapeMap] = shape;
} }
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const { std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::map<int, std::vector<int>>(); return std::map<int, std::vector<int>>();
@ -495,7 +489,7 @@ std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap); return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
} }
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -503,7 +497,7 @@ void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
data_->params[kModelOptionAscend310OutputType] = output_type; data_->params[kModelOptionAscend310OutputType] = output_type;
} }
enum DataType Ascend310DeviceInfo::GetOutputType() const { enum DataType AscendDeviceInfo::GetOutputType() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return DataType::kTypeUnknown; return DataType::kTypeUnknown;
@ -511,7 +505,7 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const {
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType); return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
} }
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return; return;
@ -519,7 +513,7 @@ void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode); data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
if (data_ == nullptr) { if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context."; MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>(); return std::vector<char>();

View File

@ -104,7 +104,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) {
} else if (device->GetDeviceType() == kKirinNPU) { } else if (device->GetDeviceType() == kKirinNPU) {
auto npu_context = device->Cast<KirinNPUDeviceInfo>(); auto npu_context = device->Cast<KirinNPUDeviceInfo>();
ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get()); ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
} else if (device->GetDeviceType() == kAscend310) { } else if (device->GetDeviceType() == kAscend) {
ret = AddAscend310Device(inner_context.get(), device.get()); ret = AddAscend310Device(inner_context.get(), device.get());
} }
if (ret != kSuccess) { if (ret != kSuccess) {

View File

@ -71,11 +71,11 @@ std::shared_ptr<mindspore::Context> Common::ContextAutoSet() {
auto context = std::make_shared<mindspore::Context>(); auto context = std::make_shared<mindspore::Context>();
if (device_target_str == "Ascend310") { if (device_target_str == "Ascend310") {
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
ascend310_info->SetDeviceID(device_id); ascend310_info->SetDeviceID(device_id);
context->MutableDeviceInfo().emplace_back(ascend310_info); context->MutableDeviceInfo().emplace_back(ascend310_info);
} else if (device_target_str == "Ascend910") { } else if (device_target_str == "Ascend910") {
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>(); auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
ascend310_info->SetDeviceID(device_id); ascend310_info->SetDeviceID(device_id);
context->MutableDeviceInfo().emplace_back(ascend310_info); context->MutableDeviceInfo().emplace_back(ascend310_info);
} else { } else {

View File

@ -101,7 +101,7 @@ TEST_F(TestDE, TestDvpp) {
auto context = ContextAutoSet(); auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr); ASSERT_TRUE(ascend310_info != nullptr);
auto device_id = ascend310_info->GetDeviceID(); auto device_id = ascend310_info->GetDeviceID();
@ -154,7 +154,7 @@ TEST_F(TestDE, TestDvppSinkMode) {
auto context = ContextAutoSet(); auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr); ASSERT_TRUE(ascend310_info != nullptr);
auto device_id = ascend310_info->GetDeviceID(); auto device_id = ascend310_info->GetDeviceID();
@ -202,7 +202,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
auto context = ContextAutoSet(); auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr); ASSERT_TRUE(ascend310_info != nullptr);
auto device_id = ascend310_info->GetDeviceID(); auto device_id = ascend310_info->GetDeviceID();

View File

@ -38,7 +38,7 @@ TEST_F(TestDynamicBatchSize, InferMindIR) {
auto context = ContextAutoSet(); auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr); ASSERT_TRUE(ascend310_info != nullptr);
std::map<int, std::vector<int>> input_shape; std::map<int, std::vector<int>> input_shape;

View File

@ -59,7 +59,7 @@ TEST_F(TestZeroCopy, TestMindIR) {
auto context = ContextAutoSet(); auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr); ASSERT_TRUE(ascend310_info != nullptr);
ascend310_info->SetInsertOpConfigPath(aipp_path); ascend310_info->SetInsertOpConfigPath(aipp_path);
auto device_id = ascend310_info->GetDeviceID(); auto device_id = ascend310_info->GetDeviceID();
@ -107,7 +107,7 @@ TEST_F(TestZeroCopy, TestDeviceTensor) {
auto context = ContextAutoSet(); auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr); ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1); ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr); ASSERT_TRUE(ascend310_info != nullptr);
ascend310_info->SetInsertOpConfigPath(aipp_path); ascend310_info->SetInsertOpConfigPath(aipp_path);
auto device_id = ascend310_info->GetDeviceID(); auto device_id = ascend310_info->GetDeviceID();

View File

@ -27,32 +27,27 @@ TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) {
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>(); std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>(); std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>(); std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>(); std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr); ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr);
ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr); ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr);
ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr); ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr);
ASSERT_TRUE(ascend310->Cast<Ascend310DeviceInfo>() != nullptr); ASSERT_TRUE(ascend->Cast<AscendDeviceInfo>() != nullptr);
ASSERT_TRUE(ascend910->Cast<Ascend910DeviceInfo>() != nullptr);
} }
TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) { TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) {
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>(); std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>(); std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>(); std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>(); std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr); ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr); ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend310->Cast<GPUDeviceInfo>() == nullptr); ASSERT_TRUE(ascend->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend910->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr); ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr); ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend310->Cast<CPUDeviceInfo>() == nullptr); ASSERT_TRUE(ascend->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend910->Cast<CPUDeviceInfo>() == nullptr);
} }
TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) { TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
@ -86,7 +81,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
std::string option_9_ans = "1,2,3,4,5"; std::string option_9_ans = "1,2,3,4,5";
auto context = std::make_shared<Context>(); auto context = std::make_shared<Context>();
std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>(); std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>();
ascend310->SetInputShape(option_1); ascend310->SetInputShape(option_1);
ascend310->SetInsertOpConfigPath(option_2); ascend310->SetInsertOpConfigPath(option_2);
ascend310->SetOpSelectImplMode(option_3); ascend310->SetOpSelectImplMode(option_3);
@ -99,7 +94,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
context->MutableDeviceInfo().push_back(ascend310); context->MutableDeviceInfo().push_back(ascend310);
ASSERT_EQ(context->MutableDeviceInfo().size(), 1); ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
auto ctx = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>(); auto ctx = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ctx != nullptr); ASSERT_TRUE(ctx != nullptr);
ASSERT_EQ(ascend310->GetInputShape(), option_1); ASSERT_EQ(ascend310->GetInputShape(), option_1);
ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2); ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2);
@ -113,7 +108,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
} }
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) { TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
auto ctx = std::make_shared<Ascend310DeviceInfo>(); auto ctx = std::make_shared<AscendDeviceInfo>();
ASSERT_EQ(ctx->GetOpSelectImplMode(), ""); ASSERT_EQ(ctx->GetOpSelectImplMode(), "");
} }
} // namespace mindspore } // namespace mindspore