forked from mindspore-Ecosystem/mindspore
C++ inference interface Ascend310DeviceInfo+Ascend910DeviceInfo->AscendDeviceInfo
This commit is contained in:
parent 87674cf3bd
commit 0a0419498f
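A minimal usage sketch (not taken from this commit), assuming the public Context/AscendDeviceInfo API shown in the hunks below; the variable names are illustrative only:

#include "include/api/context.h"

// One device-info class now covers both Ascend 310 and Ascend 910 targets;
// the concrete backend is chosen at runtime from the SoC name (see acl_utils.h below).
auto context = std::make_shared<mindspore::Context>();
auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();  // GetDeviceType() == kAscend
ascend_info->SetDeviceID(0);
context->MutableDeviceInfo().push_back(ascend_info);
// Existing callers still compile: Ascend310DeviceInfo and Ascend910DeviceInfo
// remain as type aliases of AscendDeviceInfo.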
@ -28,6 +28,7 @@ enum DeviceType {
kCPU = 0,
kGPU,
kKirinNPU,
kAscend,
kAscend910,
kAscend310,
// add new type here

@ -287,34 +288,14 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
 public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };

/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);

/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;
};

/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
class MS_API AscendDeviceInfo : public DeviceInfoContext {
 public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };

/// \brief Set device id.
///

@ -447,45 +428,48 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
std::vector<char> GetBufferOptimizeModeChar() const;
};

void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;

void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

void Ascend310DeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
SetDynamicImageSize(StringToChar(dynamic_image_size));
}

std::string Ascend310DeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }

void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
}

void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

@ -193,7 +193,7 @@ class MS_API Model {

/// \brief Inference model.
///
/// \param[in] device_type Device type,options are kGPU, kAscend910, etc.
/// \param[in] device_type Device type,options are kGPU, kAscend, kAscend910, etc.
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
///
/// \return Is supported or not.
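A minimal sketch of this documented call (not from this commit), assuming CheckModelSupport is a static member as in the released headers; with the new enum value the generic kAscend type can be queried directly:

#include "include/api/model.h"

// true if the current package and SoC can run MindIR models on an Ascend device
bool supported = mindspore::Model::CheckModelSupport(mindspore::kAscend, mindspore::ModelType::kMindIR);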
@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H

#include <string>
#include "acl/acl_base.h"
namespace mindspore {
static inline bool IsAscend910Soc() {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") == std::string::npos) {
return false;
}
return true;
}

static inline bool IsAscendNo910Soc() {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
return false;
}
return true;
}
} // namespace mindspore

#endif  // MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
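These two helpers are what lets a single kAscend request resolve to different backends: each implementation's CheckDeviceSupport (shown in later hunks) combines the requested DeviceType with the SoC name reported by ACL. A condensed sketch of that pattern, mirroring the AclGraphImpl/AscendGraphImpl code below rather than adding new behavior:

// 310-family backends accept kAscend/kAscend310 and require a non-910 SoC:
//   return (device_type == kAscend || device_type == kAscend310) && IsAscendNo910Soc();
// 910-family backends accept kAscend/kAscend910 and require a 910 SoC:
//   return (device_type == kAscend || device_type == kAscend910) && IsAscend910Soc();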
@ -175,55 +175,46 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
return StringToChar(ref);
}

void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend910DeviceID] = device_id;
}
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID);
}

void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310DeviceID] = device_id;
}
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
uint32_t AscendDeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
}

void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
}
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
}
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
MS_EXCEPTION_IF_NULL(data_);
std::string batchs = "";
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
@ -234,69 +225,69 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
}
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
}
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }

std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }

void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
}
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
}
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
return StringToChar(ref);
}

void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShapeMap] = shape;
}
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
}

void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OutputType] = output_type;
}
enum DataType Ascend310DeviceInfo::GetOutputType() const {
enum DataType AscendDeviceInfo::GetOutputType() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
}

void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
}
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
return StringToChar(ref);
@ -24,7 +24,31 @@
#include "utils/utils.h"

namespace mindspore {
inline std::string g_device_target = "Default";
inline enum DeviceType g_device_target = kInvalidDeviceType;

static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
switch (device_type) {
case kAscend:
stream << "Ascend";
break;
case kAscend910:
stream << "Ascend910";
break;
case kAscend310:
stream << "Ascend310";
break;
case kGPU:
stream << "GPU";
break;
case kCPU:
stream << "CPU";
break;
default:
stream << "[InvalidDeviceType: " << static_cast<int>(device_type) << "]";
break;
}
return stream;
}

template <class T>
class Factory {
@ -39,32 +63,24 @@ class Factory {
return instance;
}

void Register(const std::string &device_name, U &&creator) {
if (creators_.find(device_name) == creators_.end()) {
(void)creators_.emplace(device_name, creator);
void Register(U &&creator) { creators_.push_back(creator); }

std::shared_ptr<T> Create(enum DeviceType device_type) {
for (auto &item : creators_) {
MS_EXCEPTION_IF_NULL(item);
auto val = item();
if (val->CheckDeviceSupport(device_type)) {
return val;
}
}
}

bool CheckModelSupport(const std::string &device_name) {
return std::any_of(creators_.begin(), creators_.end(),
[&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; });
}

std::shared_ptr<T> Create(const std::string &device_name) {
auto iter = creators_.find(device_name);
if (creators_.end() != iter) {
MS_EXCEPTION_IF_NULL(iter->second);
return (iter->second)();
}

MS_LOG(ERROR) << "Unsupported device target " << device_name;
MS_LOG(WARNING) << "Unsupported device target " << device_type;
return nullptr;
}

 private:
Factory() = default;
~Factory() = default;
std::map<std::string, U> creators_;
std::vector<U> creators_;
};

template <class T>
@ -72,14 +88,12 @@ class Registrar {
using U = std::function<std::shared_ptr<T>()>;

 public:
Registrar(const std::string &device_name, U creator) {
Factory<T>::Instance().Register(device_name, std::move(creator));
}
explicit Registrar(U creator) { Factory<T>::Instance().Register(std::move(creator)); }
~Registrar() = default;
};

#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
#define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_reg( \
[]() { return std::make_shared<DERIVE_CLASS>(); });
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_CXX_API_FACTORY_H
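To illustrate the reworked factory (a sketch built only from what this hunk shows; HypotheticalImpl is a made-up class name): registration no longer carries a device-name string, and lookup walks the registered creators asking each candidate whether it supports the requested DeviceType:

// Registration: one line per implementation, no device-name key.
// API_FACTORY_REG(ModelImpl, HypotheticalImpl);   // expands to a static Registrar

// Selection: the first creator whose product accepts the device type wins.
// auto impl = Factory<ModelImpl>::Instance().Create(mindspore::kAscend);
// if (impl == nullptr) { /* no registered implementation supports kAscend on this SoC */ }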
@ -18,9 +18,10 @@
#include "cxx_api/model/acl/model_converter.h"
#include "utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h"
#include "cxx_api/acl_utils.h"

namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, AclGraphImpl);

AclGraphImpl::AclGraphImpl()
: init_flag_(false),
@ -231,4 +232,12 @@ Status AclGraphImpl::ConvertToOM() {
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
return kMCFailed;
}

bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend310
if (device_type != kAscend && device_type != kAscend310) {
return false;
}
return IsAscendNo910Soc();
}
} // namespace mindspore
@ -37,6 +37,7 @@ class AclGraphImpl : public GraphCell::GraphImpl {
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;

 private:
Status ConvertToOM();
@ -18,6 +18,7 @@
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/akg_kernel_register.h"
#include "cxx_api/acl_utils.h"
#include "utils/log_adapter.h"
#include "utils/context/context_extends.h"
#include "mindspore/core/base/base_ref_utils.h"
@ -30,7 +31,7 @@
#include "pybind11/pybind11.h"

namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl);

static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
@ -382,6 +383,14 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv
return acl_env;
}

bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend910
if (device_type != kAscend && device_type != kAscend910) {
return false;
}
return IsAscend910Soc();
}

std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
@ -39,6 +39,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl {
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;

 private:
class MsEnvGuard;
@ -26,7 +26,7 @@
#include "runtime/device/gpu/cuda_driver.h"

namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, GPUGraphImpl);

GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr),
@ -291,4 +291,6 @@ std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
}
return result;
}

bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; }
} // namespace mindspore
@ -37,6 +37,8 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

bool CheckDeviceSupport(mindspore::DeviceType device_type) override;

 private:
Status InitEnv();
Status FinalizeEnv();
@ -42,6 +42,8 @@ class GraphCell::GraphImpl {
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;

virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;

 protected:
std::shared_ptr<Graph> graph_;
std::shared_ptr<Context> graph_context_;
@ -21,7 +21,7 @@
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/acl/acl_env_guard.h"
#include "acl/acl_base.h"
#include "cxx_api/acl_utils.h"

namespace mindspore {
Status AclModel::Build() {
@ -112,7 +112,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s

if (model_context_ == nullptr) {
model_context_ = std::make_shared<Context>();
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>());
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<AscendDeviceInfo>());
}

std::string input_shape_option;
@ -139,7 +139,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
return kMCInvalidArgs;
}
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
MS_EXCEPTION_IF_NULL(ascend310_info);
ascend310_info->SetInputShape(input_shape_option);
auto graph_cell_bak = std::move(graph_cell_);
@ -163,16 +163,15 @@ std::vector<MSTensor> AclModel::GetOutputs() {
return graph_cell_->GetOutputs();
}

bool AclModel::CheckModelSupport(enum ModelType model_type) {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend310
if (device_type != kAscend && device_type != kAscend310) {
return false;
}
return IsAscendNo910Soc();
}

bool AclModel::CheckModelSupport(enum ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM};
auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) {
@ -43,6 +43,7 @@ class AclModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
bool CheckModelSupport(enum ModelType model_type) override;

 private:
@ -29,7 +29,7 @@
#include "cxx_api/model/acl/acl_vm/acl_vm.h"

namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
API_FACTORY_REG(ModelImpl, AclModelMulti);

namespace {
std::map<DataType, size_t> kDtypeMap = {
@ -33,7 +33,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
if (device_infos.size() != 1) {
return;
}
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
if (ascend310_info == nullptr) {
return;
}
@ -20,19 +20,6 @@
#include "utils/utils.h"

namespace mindspore {
namespace {
std::string GetDeviceTypeString(enum DeviceType type) {
static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
{kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
};
auto iter = kDeviceTypeStrs.find(type);
if (iter != kDeviceTypeStrs.end()) {
return iter->second;
}

return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
}
} // namespace
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
const std::shared_ptr<TrainCfg> &) {
if (graph_cell.GetGraph() == nullptr) {
@ -50,7 +37,7 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
return kMCInvalidInput;
}

std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
auto device_target = device_info[0]->GetDeviceType();
impl_ = Factory<ModelImpl>::Instance().Create(device_target);
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Create session type " << device_target << " failed";
@ -175,16 +162,10 @@ Model::Model() : impl_(nullptr) {}
Model::~Model() {}

bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
std::string device_type_str = GetDeviceTypeString(device_type);
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
return false;
}

auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
auto check_model = Factory<ModelImpl>::Instance().Create(device_type);
if (check_model == nullptr) {
return false;
}

return check_model->CheckModelSupport(model_type);
}

@ -44,7 +44,8 @@ class ModelImpl {
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;

virtual bool CheckModelSupport(enum ModelType model_type) { return false; }
virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
virtual bool CheckModelSupport(enum ModelType model_type) = 0;

virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
@ -20,14 +20,13 @@
#include "include/api/context.h"
#include "utils/ms_context.h"
#include "cxx_api/factory.h"
#if ENABLE_D
#include "cxx_api/acl_utils.h"
#endif

namespace mindspore {
// mindspore-serving check current package for version check with ModelImpl factory.
#if ENABLE_D
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
#elif ENABLE_GPU
API_FACTORY_REG(ModelImpl, GPU, MsModel);
#endif
API_FACTORY_REG(ModelImpl, MsModel);

static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key;
@ -171,18 +170,23 @@ uint32_t MsModel::GetDeviceID() const {
return 0;
}

bool MsModel::CheckModelSupport(enum ModelType model_type) {
bool MsModel::CheckDeviceSupport(enum DeviceType device_type) {
#if ENABLE_D
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
// for Ascend, only support kAscend or kAscend910
if (device_type != kAscend && device_type != kAscend910) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") == std::string::npos) {
return IsAscend910Soc();
#else
// otherwise, only support GPU
if (device_type != kGPU) {
return false;
}
return true;
#endif
}

bool MsModel::CheckModelSupport(mindspore::ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR};
auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) {
@ -44,6 +44,7 @@ class MsModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
bool CheckModelSupport(enum ModelType model_type) override;

 private:
@ -1,18 +1,18 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

@ -25,212 +25,431 @@

namespace mindspore {
enum DeviceType {
kCPU = 0,
kGPU,
kKirinNPU,
kAscend910,
kAscend310,
// add new type here
kInvalidDeviceType = 100,
kCPU = 0,
kGPU,
kKirinNPU,
kAscend,
kAscend910,
kAscend310,
// add new type here
kInvalidDeviceType = 100,
};

class Allocator;
class Delegate;
class DeviceInfoContext;

/// \brief Context is used to store environment variables during execution.
class MS_API Context {
 public:
Context();
~Context() = default;
 public:
struct Data;
Context();
~Context() = default;

void SetThreadNum(int32_t thread_num);
int32_t GetThreadNum() const;
/// \brief Set the number of threads at runtime. Only valid for Lite.
///
/// \param[in] thread_num the number of threads at runtime.
void SetThreadNum(int32_t thread_num);

void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;
/// \brief Get the current thread number setting. Only valid for Lite.
///
/// \return The current thread number setting.
int32_t GetThreadNum() const;

std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
/// \brief Set the thread affinity to CPU cores. Only valid for Lite.
///
/// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);

 private:
struct Data;
std::shared_ptr<Data> data_;
/// \brief Get the thread affinity of CPU cores. Only valid for Lite.
///
/// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
int GetThreadAffinityMode() const;

/// \brief Set the thread lists to CPU cores. Only valid for Lite.
///
/// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the
/// mode is not effective.
///
/// \param[in] core_list: a vector of thread core lists.
void SetThreadAffinity(const std::vector<int> &core_list);

/// \brief Get the thread lists of CPU cores. Only valid for Lite.
///
/// \return core_list: a vector of thread core lists.
std::vector<int32_t> GetThreadAffinityCoreList() const;

/// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite.
///
/// \param[in] is_parallel: true, parallel; false, not in parallel.
void SetEnableParallel(bool is_parallel);

/// \brief Get the status whether to perform model inference or training in parallel. Only valid for Lite.
///
/// \return Bool value that indicates whether in parallel.
bool GetEnableParallel() const;

/// \brief Set Delegate to access third-party AI framework. Only valid for Lite.
///
/// \param[in] Pointer to the custom delegate.
void SetDelegate(const std::shared_ptr<Delegate> &delegate);

/// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
///
/// \return Pointer to the custom delegate.
std::shared_ptr<Delegate> GetDelegate() const;

/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
/// heterogeneous scenarios with multiple members in the vector.
///
/// \return Mutable reference of DeviceInfoContext vector in this context.
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

 private:
std::shared_ptr<Data> data_;
};

/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
struct Data;
 public:
struct Data;

DeviceInfoContext();
virtual ~DeviceInfoContext() = default;
virtual enum DeviceType GetDeviceType() const = 0;
DeviceInfoContext();
virtual ~DeviceInfoContext() = default;

template <class T>
std::shared_ptr<T> Cast() {
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
if (GetDeviceType() != T().GetDeviceType()) {
return nullptr;
}
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
virtual enum DeviceType GetDeviceType() const = 0;

return std::static_pointer_cast<T>(shared_from_this());
}
/// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
/// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
///
/// \param T Type
/// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
template <class T>
std::shared_ptr<T> Cast() {
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
if (GetDeviceType() != T().GetDeviceType()) {
return nullptr;
}

 protected:
std::shared_ptr<Data> data_;
return std::static_pointer_cast<T>(shared_from_this());
}
/// \brief obtain provider's name
///
/// \return provider's name.
std::string GetProvider() const;
/// \brief set provider's name.
///
/// \param[in] provider define the provider's name.

void SetProvider(const std::string &provider);
/// \brief obtain provider's device type.
///
/// \return provider's device type.

std::string GetProviderDevice() const;
/// \brief set provider's device type.
///
/// \param[in] device define the provider's device type.EG: CPU.
void SetProviderDevice(const std::string &device);

/// \brief set memory allocator.
///
/// \param[in] allocator define the memory allocator which can be defined by user.
void SetAllocator(const std::shared_ptr<Allocator> &allocator);

/// \brief obtain memory allocator.
///
/// \return memory allocator.
std::shared_ptr<Allocator> GetAllocator() const;

 protected:
std::shared_ptr<Data> data_;
};

/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
/// for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {
 public:
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
 public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };

/// \brief Set the thread affinity to CPU cores.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinity() const;
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
/// \brief Set enables to perform the float16 inference
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);

/// \brief Get enables to perform the float16 inference
///
/// \return Whether enable float16 inference.
bool GetEnableFP16() const;
};

/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
/// for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
 public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };

void SetFrequency(int frequency);
int GetFrequency() const;
/// \brief Set the NPU frequency.
///
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
/// performance), default as 3.
void SetFrequency(int frequency);

/// \brief Get the NPU frequency.
///
/// \return NPU frequency
int GetFrequency() const;
};

/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext {
 public:
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
 public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };

void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);

void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
bool GetGpuTrtInferMode() const;
/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;

void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
/// \brief Get the distribution rank id.
///
/// \return The device id.
int GetRankID() const;

/// \brief Get the distribution group size.
///
/// \return The device id.
int GetGroupSize() const;

/// \brief Set the precision mode.
///
/// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
inline void SetPrecisionMode(const std::string &precision_mode);

/// \brief Get the precision mode.
///
/// \return The precision mode.
inline std::string GetPrecisionMode() const;

/// \brief Set enables to perform the float16 inference
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);

/// \brief Get enables to perform the float16 inference
///
/// \return Whether enable float16 inference.
bool GetEnableFP16() const;

 private:
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
};

class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
 public:
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
/// invalid for MindSpore Lite.
class MS_API AscendDeviceInfo : public DeviceInfoContext {
 public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };

/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);

/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;

/// \brief Set AIPP configuration file path.
///
/// \param[in] cfg_path AIPP configuration file path.
inline void SetInsertOpConfigPath(const std::string &cfg_path);

/// \brief Get AIPP configuration file path.
///
/// \return AIPP configuration file path.
inline std::string GetInsertOpConfigPath() const;

/// \brief Set format of model inputs.
///
/// \param[in] format Optional "NCHW", "NHWC", etc.
inline void SetInputFormat(const std::string &format);

/// \brief Get format of model inputs.
///
/// \return The format of model inputs.
inline std::string GetInputFormat() const;

/// \brief Set shape of model inputs.
///
/// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
inline void SetInputShape(const std::string &shape);

/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
inline std::string GetInputShape() const;

/// \brief Set shape of model inputs.
///
/// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
/// shape 4,3,2,1.
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);

/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
std::map<int, std::vector<int>> GetInputShapeMap() const;

void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;

/// \brief Set the dynamic image size of model inputs.
///
/// \param[in] image size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
inline void SetDynamicImageSize(const std::string &dynamic_image_size);

/// \brief Get dynamic image size of model inputs.
///
/// \return The image size of model inputs.
inline std::string GetDynamicImageSize() const;

/// \brief Set type of model outputs.
///
/// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
void SetOutputType(enum DataType output_type);

/// \brief Get type of model outputs.
///
/// \return The set type of model outputs.
enum DataType GetOutputType() const;

/// \brief Set precision mode of model.
///
/// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
/// "allow_mix_precision", "force_fp16" is set as default
inline void SetPrecisionMode(const std::string &precision_mode);

/// \brief Get precision mode of model.
///
/// \return The set type of model outputs
inline std::string GetPrecisionMode() const;

/// \brief Set op select implementation mode.
///
/// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
/// default.
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);

/// \brief Get op select implementation mode.
///
/// \return The set op select implementation mode.
inline std::string GetOpSelectImplMode() const;

inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;

// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;

 private:
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;

void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;

void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;

std::vector<char> GetDynamicBatchSizeChar() const;

void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
std::vector<char> GetDynamicImageSizeChar() const;

void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;

void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;

void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;

void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
};

class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
 public:
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;

void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;

inline void SetDumpConfigPath(const std::string &cfg_path);
inline std::string GetDumpConfigPath() const;

// aipp config file
inline void SetInsertOpConfigPath(const std::string &cfg_path);
inline std::string GetInsertOpConfigPath() const;

// nchw or nhwc
inline void SetInputFormat(const std::string &format);
inline std::string GetInputFormat() const;

// Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
inline void SetInputShape(const std::string &shape);
inline std::string GetInputShape() const;

void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
std::map<int, std::vector<int>> GetInputShapeMap() const;

void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;

// FP32, UINT8 or FP16, default as FP32
void SetOutputType(enum DataType output_type);
enum DataType GetOutputType() const;

// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", default as "force_fp16"
inline void SetPrecisionMode(const std::string &precision_mode);
inline std::string GetPrecisionMode() const;

// Optional "high_performance" and "high_precision", "high_performance" is set as default
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
inline std::string GetOpSelectImplMode() const;

inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;

// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;

 private:
void SetDumpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetDumpConfigPathChar() const;

void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;

void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;

void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;

std::vector<char> GetDynamicBatchSizeChar() const;

void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;

void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;

void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;

void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
};

void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }

void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
  SetDynamicImageSize(StringToChar(dynamic_image_size));
}

void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }

void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
}

void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
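A side note on the pattern above: every std::string option is an inline header wrapper that forwards to a std::vector<char> overload via StringToChar/CharToString, presumably so that std::string never crosses the library boundary. A self-contained sketch of that pattern with a made-up ExampleDeviceInfo and MyOption; only the helper names are taken from this diff.

#include <string>
#include <vector>

// Hypothetical helpers mirroring StringToChar/CharToString from the diff.
inline std::vector<char> StringToChar(const std::string &s) { return {s.begin(), s.end()}; }
inline std::string CharToString(const std::vector<char> &c) { return {c.begin(), c.end()}; }

class ExampleDeviceInfo {
 public:
  // Public std::string API, implemented inline in the header.
  void SetMyOption(const std::string &v) { SetMyOption(StringToChar(v)); }
  std::string GetMyOption() const { return CharToString(GetMyOptionChar()); }

 private:
  // The char-based overloads are the ones that cross the library boundary.
  void SetMyOption(const std::vector<char> &v) { stored_ = v; }
  std::vector<char> GetMyOptionChar() const { return stored_; }
  std::vector<char> stored_;
};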
@@ -317,13 +317,7 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
  return ret;
}

void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return 0;
}

void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;
@@ -331,7 +325,7 @@ void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
  data_->params[kModelOptionAscend310DeviceID] = device_id;
}

uint32_t Ascend310DeviceInfo::GetDeviceID() const {
uint32_t AscendDeviceInfo::GetDeviceID() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return 0;

@@ -339,14 +333,14 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
  return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
}

void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;
  }
  data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();
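Each setter in this implementation file follows the same shape: bail out if data_ is null, otherwise store the option in data_->params under a string key; the matching getter reads it back through GetValue<T>. A simplified, self-contained sketch of that storage pattern — the Data struct, the std::any storage, and the "device_id" key are stand-ins of mine, not the real definitions.

#include <any>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for the internal Data holder: a string-keyed bag of options.
struct Data {
  std::map<std::string, std::any> params;
};

// Stand-in for GetValue<T>: missing or mismatched entries fall back to a default value.
template <typename T>
T GetValue(const std::shared_ptr<Data> &data, const std::string &key) {
  auto it = data->params.find(key);
  if (it == data->params.end()) {
    return T();
  }
  auto *value = std::any_cast<T>(&it->second);
  return value != nullptr ? *value : T();
}

class ExampleDeviceInfo {
 public:
  void SetDeviceID(uint32_t device_id) {
    if (data_ == nullptr) {
      std::cerr << "Invalid context." << std::endl;  // MS_LOG(ERROR) in the real code
      return;
    }
    data_->params["device_id"] = device_id;
  }
  uint32_t GetDeviceID() const {
    if (data_ == nullptr) {
      std::cerr << "Invalid context." << std::endl;
      return 0;
    }
    return GetValue<uint32_t>(data_, "device_id");
  }

 private:
  std::shared_ptr<Data> data_ = std::make_shared<Data>();
};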
@@ -355,7 +349,7 @@ std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -363,7 +357,7 @@ void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
  data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
}

std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();

@@ -372,14 +366,14 @@ std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;
  }
  data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
}
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();

@@ -388,7 +382,7 @@ std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -403,7 +397,7 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
  data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
}

std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();

@@ -412,7 +406,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -420,7 +414,7 @@ void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_i
  data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size);
}

std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();
@@ -429,7 +423,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -437,7 +431,7 @@ void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mo
  data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
}

std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();

@@ -446,7 +440,7 @@ std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -454,7 +448,7 @@ void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select
  data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
}

std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();

@@ -463,14 +457,14 @@ std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;
  }
  data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();
@@ -479,7 +473,7 @@ std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
  return StringToChar(ref);
}

void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -487,7 +481,7 @@ void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>>
  data_->params[kModelOptionAscend310InputShapeMap] = shape;
}

std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::map<int, std::vector<int>>();

@@ -495,7 +489,7 @@ std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
  return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
}

void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -503,7 +497,7 @@ void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
  data_->params[kModelOptionAscend310OutputType] = output_type;
}

enum DataType Ascend310DeviceInfo::GetOutputType() const {
enum DataType AscendDeviceInfo::GetOutputType() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return DataType::kTypeUnknown;

@@ -511,7 +505,7 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const {
  return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
}

void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return;

@@ -519,7 +513,7 @@ void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_
  data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
}

std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
  if (data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return std::vector<char>();
@@ -104,7 +104,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) {
    } else if (device->GetDeviceType() == kKirinNPU) {
      auto npu_context = device->Cast<KirinNPUDeviceInfo>();
      ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
    } else if (device->GetDeviceType() == kAscend310) {
    } else if (device->GetDeviceType() == kAscend) {
      ret = AddAscend310Device(inner_context.get(), device.get());
    }
    if (ret != kSuccess) {
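With the enum collapsed to kAscend, converter code only needs one branch for Ascend devices. A hedged sketch of that dispatch in isolation; RegisterDevices and its backend hand-off are illustrative, the real converter calls an internal helper not shown here.

#include <memory>
#include "include/api/context.h"  // path inferred from the include guard in this diff

// Sketch: converters now branch on the single kAscend device type.
void RegisterDevices(const std::shared_ptr<mindspore::Context> &context) {
  for (const auto &device : context->MutableDeviceInfo()) {
    if (device->GetDeviceType() == mindspore::DeviceType::kAscend) {
      auto ascend = device->Cast<mindspore::AscendDeviceInfo>();
      if (ascend == nullptr) {
        continue;  // defensive: Cast returns nullptr on a type mismatch
      }
      // Hand off to the backend here; the real code calls an internal
      // AddAscend310Device-style helper, which is not part of the public API.
    }
  }
}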
@@ -71,11 +71,11 @@ std::shared_ptr<mindspore::Context> Common::ContextAutoSet() {
  auto context = std::make_shared<mindspore::Context>();

  if (device_target_str == "Ascend310") {
    auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
    auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
    ascend310_info->SetDeviceID(device_id);
    context->MutableDeviceInfo().emplace_back(ascend310_info);
  } else if (device_target_str == "Ascend910") {
    auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
    auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
    ascend310_info->SetDeviceID(device_id);
    context->MutableDeviceInfo().emplace_back(ascend310_info);
  } else {
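Since both the "Ascend310" and "Ascend910" target strings now produce the same device info, the two branches above could be collapsed. A sketch of such a helper, assuming the target string and device id are supplied by the caller as in ContextAutoSet.

#include <memory>
#include <string>
#include "include/api/context.h"  // path inferred from the include guard in this diff

// Sketch: one code path for every Ascend target string.
std::shared_ptr<mindspore::Context> MakeContextFor(const std::string &device_target, uint32_t device_id) {
  auto context = std::make_shared<mindspore::Context>();
  if (device_target == "Ascend310" || device_target == "Ascend910") {
    auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();
    ascend_info->SetDeviceID(device_id);
    context->MutableDeviceInfo().emplace_back(ascend_info);
  }
  return context;
}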
@@ -101,7 +101,7 @@ TEST_F(TestDE, TestDvpp) {
  auto context = ContextAutoSet();
  ASSERT_TRUE(context != nullptr);
  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ascend310_info != nullptr);
  auto device_id = ascend310_info->GetDeviceID();

@@ -154,7 +154,7 @@ TEST_F(TestDE, TestDvppSinkMode) {
  auto context = ContextAutoSet();
  ASSERT_TRUE(context != nullptr);
  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ascend310_info != nullptr);
  auto device_id = ascend310_info->GetDeviceID();

@@ -202,7 +202,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
  auto context = ContextAutoSet();
  ASSERT_TRUE(context != nullptr);
  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ascend310_info != nullptr);
  auto device_id = ascend310_info->GetDeviceID();
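On the test side the only change is the Cast target. A hedged sketch of that retrieval pattern as a standalone helper; the single-device assumption mirrors the asserts above, and DeviceIdOf is a name of mine, not part of the API.

#include <cstdint>
#include <memory>
#include "include/api/context.h"  // path inferred from the include guard in this diff

// Sketch: pull the configured Ascend device back out of a context and read its id.
uint32_t DeviceIdOf(const std::shared_ptr<mindspore::Context> &context) {
  if (context == nullptr || context->MutableDeviceInfo().size() != 1) {
    return 0;  // the tests above assert exactly one device; treat anything else as "no id"
  }
  auto ascend_info = context->MutableDeviceInfo()[0]->Cast<mindspore::AscendDeviceInfo>();
  return ascend_info != nullptr ? ascend_info->GetDeviceID() : 0;
}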
@@ -38,7 +38,7 @@ TEST_F(TestDynamicBatchSize, InferMindIR) {
  auto context = ContextAutoSet();
  ASSERT_TRUE(context != nullptr);
  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ascend310_info != nullptr);

  std::map<int, std::vector<int>> input_shape;
@@ -59,7 +59,7 @@ TEST_F(TestZeroCopy, TestMindIR) {
  auto context = ContextAutoSet();
  ASSERT_TRUE(context != nullptr);
  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ascend310_info != nullptr);
  ascend310_info->SetInsertOpConfigPath(aipp_path);
  auto device_id = ascend310_info->GetDeviceID();

@@ -107,7 +107,7 @@ TEST_F(TestZeroCopy, TestDeviceTensor) {
  auto context = ContextAutoSet();
  ASSERT_TRUE(context != nullptr);
  ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ascend310_info != nullptr);
  ascend310_info->SetInsertOpConfigPath(aipp_path);
  auto device_id = ascend310_info->GetDeviceID();
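SetInsertOpConfigPath points the Ascend backend at an AIPP insert-op configuration file, which the zero-copy tests set before running. A small sketch of wiring that up; the path is a placeholder and the existence check is my own addition, not something the API performs.

#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "include/api/context.h"  // path inferred from the include guard in this diff

// Sketch: attach an AIPP insert-op config to the Ascend device info.
bool ConfigureAipp(const std::shared_ptr<mindspore::AscendDeviceInfo> &ascend_info,
                   const std::string &aipp_path) {
  std::ifstream cfg(aipp_path);
  if (!cfg.good()) {
    std::cerr << "AIPP config not found: " << aipp_path << std::endl;
    return false;
  }
  ascend_info->SetInsertOpConfigPath(aipp_path);
  return true;
}

// Usage (placeholder path): ConfigureAipp(ascend_info, "./data/aipp.cfg");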
@@ -27,32 +27,27 @@ TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) {
  std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
  std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
  std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
  std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
  std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
  std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();

  ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr);
  ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr);
  ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr);
  ASSERT_TRUE(ascend310->Cast<Ascend310DeviceInfo>() != nullptr);
  ASSERT_TRUE(ascend910->Cast<Ascend910DeviceInfo>() != nullptr);
  ASSERT_TRUE(ascend->Cast<AscendDeviceInfo>() != nullptr);
}

TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) {
  std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
  std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
  std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
  std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
  std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
  std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();

  ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(ascend310->Cast<GPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(ascend910->Cast<GPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(ascend->Cast<GPUDeviceInfo>() == nullptr);

  ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(ascend310->Cast<CPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(ascend910->Cast<CPUDeviceInfo>() == nullptr);
  ASSERT_TRUE(ascend->Cast<CPUDeviceInfo>() == nullptr);
}

TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
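These two tests pin down the Cast contract: Cast<T>() yields a usable pointer only when the dynamic device type matches, and nullptr otherwise. A short usage sketch built on that contract; PrintAscendPrecision is illustrative only.

#include <iostream>
#include <memory>
#include "include/api/context.h"  // path inferred from the include guard in this diff

// Sketch: guard every Cast with a null check before touching device-specific options.
void PrintAscendPrecision(const std::shared_ptr<mindspore::DeviceInfoContext> &device) {
  auto ascend = device->Cast<mindspore::AscendDeviceInfo>();
  if (ascend == nullptr) {
    std::cout << "not an Ascend device" << std::endl;  // e.g. a CPUDeviceInfo or GPUDeviceInfo
    return;
  }
  std::cout << "precision mode: " << ascend->GetPrecisionMode() << std::endl;
}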
@@ -86,7 +81,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
  std::string option_9_ans = "1,2,3,4,5";

  auto context = std::make_shared<Context>();
  std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>();
  std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>();
  ascend310->SetInputShape(option_1);
  ascend310->SetInsertOpConfigPath(option_2);
  ascend310->SetOpSelectImplMode(option_3);

@@ -99,7 +94,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {

  context->MutableDeviceInfo().push_back(ascend310);
  ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
  auto ctx = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
  auto ctx = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
  ASSERT_TRUE(ctx != nullptr);
  ASSERT_EQ(ascend310->GetInputShape(), option_1);
  ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2);

@@ -113,7 +108,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
}

TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
  auto ctx = std::make_shared<Ascend310DeviceInfo>();
  auto ctx = std::make_shared<AscendDeviceInfo>();
  ASSERT_EQ(ctx->GetOpSelectImplMode(), "");
}
}  // namespace mindspore