C++ inference interface Ascend310DeviceInfo+Ascend910DeviceInfo->AscendDeviceInfo

This commit is contained in:
xuyongfei 2021-12-09 15:20:31 +08:00
parent 87674cf3bd
commit 0a0419498f
28 changed files with 632 additions and 374 deletions

View File

@ -28,6 +28,7 @@ enum DeviceType {
kCPU = 0,
kGPU,
kKirinNPU,
kAscend,
kAscend910,
kAscend310,
// add new type here
@ -287,34 +288,14 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
class MS_API AscendDeviceInfo : public DeviceInfoContext {
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
/// \brief Set device id.
///
@ -447,45 +428,48 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
std::vector<char> GetBufferOptimizeModeChar() const;
};
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;
void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void Ascend310DeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
SetDynamicImageSize(StringToChar(dynamic_image_size));
}
std::string Ascend310DeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
}
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

View File

@ -193,7 +193,7 @@ class MS_API Model {
/// \brief Inference model.
///
/// \param[in] device_type Device typeoptions are kGPU, kAscend910, etc.
/// \param[in] device_type Device typeoptions are kGPU, kAscend, kAscend910, etc.
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
///
/// \return Is supported or not.

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
#define MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
#include <string>
#include "acl/acl_base.h"
namespace mindspore {
static inline bool IsAscend910Soc() {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") == std::string::npos) {
return false;
}
return true;
}
static inline bool IsAscendNo910Soc() {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
return false;
}
return true;
}
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H

View File

@ -175,55 +175,46 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
return StringToChar(ref);
}
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend910DeviceID] = device_id;
}
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID);
}
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310DeviceID] = device_id;
}
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
uint32_t AscendDeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
}
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
}
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
}
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
MS_EXCEPTION_IF_NULL(data_);
std::string batchs = "";
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
@ -234,69 +225,69 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
}
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
}
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }
std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
}
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
}
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShapeMap] = shape;
}
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
}
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OutputType] = output_type;
}
enum DataType Ascend310DeviceInfo::GetOutputType() const {
enum DataType AscendDeviceInfo::GetOutputType() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
}
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
}
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
return StringToChar(ref);

View File

@ -24,7 +24,31 @@
#include "utils/utils.h"
namespace mindspore {
inline std::string g_device_target = "Default";
inline enum DeviceType g_device_target = kInvalidDeviceType;
static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
switch (device_type) {
case kAscend:
stream << "Ascend";
break;
case kAscend910:
stream << "Ascend910";
break;
case kAscend310:
stream << "Ascend310";
break;
case kGPU:
stream << "GPU";
break;
case kCPU:
stream << "CPU";
break;
default:
stream << "[InvalidDeviceType: " << static_cast<int>(device_type) << "]";
break;
}
return stream;
}
template <class T>
class Factory {
@ -39,32 +63,24 @@ class Factory {
return instance;
}
void Register(const std::string &device_name, U &&creator) {
if (creators_.find(device_name) == creators_.end()) {
(void)creators_.emplace(device_name, creator);
void Register(U &&creator) { creators_.push_back(creator); }
std::shared_ptr<T> Create(enum DeviceType device_type) {
for (auto &item : creators_) {
MS_EXCEPTION_IF_NULL(item);
auto val = item();
if (val->CheckDeviceSupport(device_type)) {
return val;
}
}
}
bool CheckModelSupport(const std::string &device_name) {
return std::any_of(creators_.begin(), creators_.end(),
[&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; });
}
std::shared_ptr<T> Create(const std::string &device_name) {
auto iter = creators_.find(device_name);
if (creators_.end() != iter) {
MS_EXCEPTION_IF_NULL(iter->second);
return (iter->second)();
}
MS_LOG(ERROR) << "Unsupported device target " << device_name;
MS_LOG(WARNING) << "Unsupported device target " << device_type;
return nullptr;
}
private:
Factory() = default;
~Factory() = default;
std::map<std::string, U> creators_;
std::vector<U> creators_;
};
template <class T>
@ -72,14 +88,12 @@ class Registrar {
using U = std::function<std::shared_ptr<T>()>;
public:
Registrar(const std::string &device_name, U creator) {
Factory<T>::Instance().Register(device_name, std::move(creator));
}
explicit Registrar(U creator) { Factory<T>::Instance().Register(std::move(creator)); }
~Registrar() = default;
};
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
#define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_reg( \
[]() { return std::make_shared<DERIVE_CLASS>(); });
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H

View File

@ -18,9 +18,10 @@
#include "cxx_api/model/acl/model_converter.h"
#include "utils/log_adapter.h"
#include "mindspore/core/utils/convert_utils_base.h"
#include "cxx_api/acl_utils.h"
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, AclGraphImpl);
AclGraphImpl::AclGraphImpl()
: init_flag_(false),
@ -231,4 +232,12 @@ Status AclGraphImpl::ConvertToOM() {
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
return kMCFailed;
}
bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend310
if (device_type != kAscend && device_type != kAscend310) {
return false;
}
return IsAscendNo910Soc();
}
} // namespace mindspore

View File

@ -37,6 +37,7 @@ class AclGraphImpl : public GraphCell::GraphImpl {
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
private:
Status ConvertToOM();

View File

@ -18,6 +18,7 @@
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/akg_kernel_register.h"
#include "cxx_api/acl_utils.h"
#include "utils/log_adapter.h"
#include "utils/context/context_extends.h"
#include "mindspore/core/base/base_ref_utils.h"
@ -30,7 +31,7 @@
#include "pybind11/pybind11.h"
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl);
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
@ -382,6 +383,14 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv
return acl_env;
}
bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend910
if (device_type != kAscend && device_type != kAscend910) {
return false;
}
return IsAscend910Soc();
}
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;

View File

@ -39,6 +39,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl {
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
private:
class MsEnvGuard;

View File

@ -26,7 +26,7 @@
#include "runtime/device/gpu/cuda_driver.h"
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, GPUGraphImpl);
GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr),
@ -291,4 +291,6 @@ std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
}
return result;
}
bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; }
} // namespace mindspore

View File

@ -37,6 +37,8 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
private:
Status InitEnv();
Status FinalizeEnv();

View File

@ -42,6 +42,8 @@ class GraphCell::GraphImpl {
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;
virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
protected:
std::shared_ptr<Graph> graph_;
std::shared_ptr<Context> graph_context_;

View File

@ -21,7 +21,7 @@
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/acl/acl_env_guard.h"
#include "acl/acl_base.h"
#include "cxx_api/acl_utils.h"
namespace mindspore {
Status AclModel::Build() {
@ -112,7 +112,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
if (model_context_ == nullptr) {
model_context_ = std::make_shared<Context>();
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>());
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<AscendDeviceInfo>());
}
std::string input_shape_option;
@ -139,7 +139,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
return kMCInvalidArgs;
}
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
MS_EXCEPTION_IF_NULL(ascend310_info);
ascend310_info->SetInputShape(input_shape_option);
auto graph_cell_bak = std::move(graph_cell_);
@ -163,16 +163,15 @@ std::vector<MSTensor> AclModel::GetOutputs() {
return graph_cell_->GetOutputs();
}
bool AclModel::CheckModelSupport(enum ModelType model_type) {
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") != std::string::npos) {
bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) {
// for Ascend, only support kAscend and kAscend310
if (device_type != kAscend && device_type != kAscend310) {
return false;
}
return IsAscendNo910Soc();
}
bool AclModel::CheckModelSupport(enum ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM};
auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) {

View File

@ -43,6 +43,7 @@ class AclModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
bool CheckModelSupport(enum ModelType model_type) override;
private:

View File

@ -29,7 +29,7 @@
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
API_FACTORY_REG(ModelImpl, AclModelMulti);
namespace {
std::map<DataType, size_t> kDtypeMap = {

View File

@ -33,7 +33,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
if (device_infos.size() != 1) {
return;
}
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
if (ascend310_info == nullptr) {
return;
}

View File

@ -20,19 +20,6 @@
#include "utils/utils.h"
namespace mindspore {
namespace {
std::string GetDeviceTypeString(enum DeviceType type) {
static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
{kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
};
auto iter = kDeviceTypeStrs.find(type);
if (iter != kDeviceTypeStrs.end()) {
return iter->second;
}
return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
}
} // namespace
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
const std::shared_ptr<TrainCfg> &) {
if (graph_cell.GetGraph() == nullptr) {
@ -50,7 +37,7 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
return kMCInvalidInput;
}
std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
auto device_target = device_info[0]->GetDeviceType();
impl_ = Factory<ModelImpl>::Instance().Create(device_target);
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Create session type " << device_target << " failed";
@ -175,16 +162,10 @@ Model::Model() : impl_(nullptr) {}
Model::~Model() {}
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
std::string device_type_str = GetDeviceTypeString(device_type);
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
return false;
}
auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
auto check_model = Factory<ModelImpl>::Instance().Create(device_type);
if (check_model == nullptr) {
return false;
}
return check_model->CheckModelSupport(model_type);
}

View File

@ -44,7 +44,8 @@ class ModelImpl {
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;
virtual bool CheckModelSupport(enum ModelType model_type) { return false; }
virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
virtual bool CheckModelSupport(enum ModelType model_type) = 0;
virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);

View File

@ -20,14 +20,13 @@
#include "include/api/context.h"
#include "utils/ms_context.h"
#include "cxx_api/factory.h"
#if ENABLE_D
#include "cxx_api/acl_utils.h"
#endif
namespace mindspore {
// mindspore-serving check current package for version check with ModelImpl factory.
#if ENABLE_D
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
#elif ENABLE_GPU
API_FACTORY_REG(ModelImpl, GPU, MsModel);
#endif
API_FACTORY_REG(ModelImpl, MsModel);
static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
std::string shape_key;
@ -171,18 +170,23 @@ uint32_t MsModel::GetDeviceID() const {
return 0;
}
bool MsModel::CheckModelSupport(enum ModelType model_type) {
bool MsModel::CheckDeviceSupport(enum DeviceType device_type) {
#if ENABLE_D
const char *soc_name_c = aclrtGetSocName();
if (soc_name_c == nullptr) {
// for Ascend, only support kAscend or kAscend910
if (device_type != kAscend && device_type != kAscend910) {
return false;
}
std::string soc_name(soc_name_c);
if (soc_name.find("910") == std::string::npos) {
return IsAscend910Soc();
#else
// otherwise, only support GPU
if (device_type != kGPU) {
return false;
}
return true;
#endif
}
bool MsModel::CheckModelSupport(mindspore::ModelType model_type) {
static const std::set<ModelType> kSupportedModelMap = {kMindIR};
auto iter = kSupportedModelMap.find(model_type);
if (iter == kSupportedModelMap.end()) {

View File

@ -44,6 +44,7 @@ class MsModel : public ModelImpl {
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
bool CheckModelSupport(enum ModelType model_type) override;
private:

View File

@ -1,18 +1,18 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H
@ -25,212 +25,431 @@
namespace mindspore {
enum DeviceType {
kCPU = 0,
kGPU,
kKirinNPU,
kAscend910,
kAscend310,
// add new type here
kInvalidDeviceType = 100,
kCPU = 0,
kGPU,
kKirinNPU,
kAscend,
kAscend910,
kAscend310,
// add new type here
kInvalidDeviceType = 100,
};
class Allocator;
class Delegate;
class DeviceInfoContext;
/// \brief Context is used to store environment variables during execution.
class MS_API Context {
public:
Context();
~Context() = default;
public:
struct Data;
Context();
~Context() = default;
void SetThreadNum(int32_t thread_num);
int32_t GetThreadNum() const;
/// \brief Set the number of threads at runtime. Only valid for Lite.
///
/// \param[in] thread_num the number of threads at runtime.
void SetThreadNum(int32_t thread_num);
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;
/// \brief Get the current thread number setting. Only valid for Lite.
///
/// \return The current thread number setting.
int32_t GetThreadNum() const;
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
/// \brief Set the thread affinity to CPU cores. Only valid for Lite.
///
/// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
private:
struct Data;
std::shared_ptr<Data> data_;
/// \brief Get the thread affinity of CPU cores. Only valid for Lite.
///
/// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
int GetThreadAffinityMode() const;
/// \brief Set the thread lists to CPU cores. Only valid for Lite.
///
/// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the
/// mode is not effective.
///
/// \param[in] core_list: a vector of thread core lists.
void SetThreadAffinity(const std::vector<int> &core_list);
/// \brief Get the thread lists of CPU cores. Only valid for Lite.
///
/// \return core_list: a vector of thread core lists.
std::vector<int32_t> GetThreadAffinityCoreList() const;
/// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite.
///
/// \param[in] is_parallel: true, parallel; false, not in parallel.
void SetEnableParallel(bool is_parallel);
/// \brief Get the status whether to perform model inference or training in parallel. Only valid for Lite.
///
/// \return Bool value that indicates whether in parallel.
bool GetEnableParallel() const;
/// \brief Set Delegate to access third-party AI framework. Only valid for Lite.
///
/// \param[in] Pointer to the custom delegate.
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
/// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
///
/// \return Pointer to the custom delegate.
std::shared_ptr<Delegate> GetDelegate() const;
/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
/// heterogeneous scenarios with multiple members in the vector.
///
/// \return Mutable reference of DeviceInfoContext vector in this context.
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
private:
std::shared_ptr<Data> data_;
};
/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
public:
struct Data;
public:
struct Data;
DeviceInfoContext();
virtual ~DeviceInfoContext() = default;
virtual enum DeviceType GetDeviceType() const = 0;
DeviceInfoContext();
virtual ~DeviceInfoContext() = default;
template <class T>
std::shared_ptr<T> Cast() {
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
if (GetDeviceType() != T().GetDeviceType()) {
return nullptr;
}
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
virtual enum DeviceType GetDeviceType() const = 0;
return std::static_pointer_cast<T>(shared_from_this());
}
/// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
/// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
///
/// \param T Type
/// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
template <class T>
std::shared_ptr<T> Cast() {
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
if (GetDeviceType() != T().GetDeviceType()) {
return nullptr;
}
protected:
std::shared_ptr<Data> data_;
return std::static_pointer_cast<T>(shared_from_this());
}
/// \brief obtain provider's name
///
/// \return provider's name.
std::string GetProvider() const;
/// \brief set provider's name.
///
/// \param[in] provider define the provider's name.
void SetProvider(const std::string &provider);
/// \brief obtain provider's device type.
///
/// \return provider's device type.
std::string GetProviderDevice() const;
/// \brief set provider's device type.
///
/// \param[in] device define the provider's device type.EG: CPU.
void SetProviderDevice(const std::string &device);
/// \brief set memory allocator.
///
/// \param[in] allocator define the memory allocator which can be defined by user.
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
/// \brief obtain memory allocator.
///
/// \return memory allocator.
std::shared_ptr<Allocator> GetAllocator() const;
protected:
std::shared_ptr<Data> data_;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
/// for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
/// \brief Set the thread affinity to CPU cores.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinity() const;
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
/// \brief Set enables to perform the float16 inference
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);
/// \brief Get enables to perform the float16 inference
///
/// \return Whether enable float16 inference.
bool GetEnableFP16() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
/// for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
void SetFrequency(int frequency);
int GetFrequency() const;
/// \brief Set the NPU frequency.
///
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
/// performance), default as 3.
void SetFrequency(int frequency);
/// \brief Get the NPU frequency.
///
/// \return NPU frequency
int GetFrequency() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
bool GetGpuTrtInferMode() const;
/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
/// \brief Get the distribution rank id.
///
/// \return The device id.
int GetRankID() const;
/// \brief Get the distribution group size.
///
/// \return The device id.
int GetGroupSize() const;
/// \brief Set the precision mode.
///
/// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
inline void SetPrecisionMode(const std::string &precision_mode);
/// \brief Get the precision mode.
///
/// \return The precision mode.
inline std::string GetPrecisionMode() const;
/// \brief Set enables to perform the float16 inference
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);
/// \brief Get enables to perform the float16 inference
///
/// \return Whether enable float16 inference.
bool GetEnableFP16() const;
private:
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
};
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
/// invalid for MindSpore Lite.
class MS_API AscendDeviceInfo : public DeviceInfoContext {
public:
/// \brief Get the type of this DeviceInfoContext.
///
/// \return Type of this DeviceInfoContext.
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
/// \brief Set device id.
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
uint32_t GetDeviceID() const;
/// \brief Set AIPP configuration file path.
///
/// \param[in] cfg_path AIPP configuration file path.
inline void SetInsertOpConfigPath(const std::string &cfg_path);
/// \brief Get AIPP configuration file path.
///
/// \return AIPP configuration file path.
inline std::string GetInsertOpConfigPath() const;
/// \brief Set format of model inputs.
///
/// \param[in] format Optional "NCHW", "NHWC", etc.
inline void SetInputFormat(const std::string &format);
/// \brief Get format of model inputs.
///
/// \return The format of model inputs.
inline std::string GetInputFormat() const;
/// \brief Set shape of model inputs.
///
/// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
inline void SetInputShape(const std::string &shape);
/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
inline std::string GetInputShape() const;
/// \brief Set shape of model inputs.
///
/// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
/// shape 4,3,2,1.
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
std::map<int, std::vector<int>> GetInputShapeMap() const;
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;
/// \brief Set the dynamic image size of model inputs.
///
/// \param[in] image size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
inline void SetDynamicImageSize(const std::string &dynamic_image_size);
/// \brief Get dynamic image size of model inputs.
///
/// \return The image size of model inputs.
inline std::string GetDynamicImageSize() const;
/// \brief Set type of model outputs.
///
/// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
void SetOutputType(enum DataType output_type);
/// \brief Get type of model outputs.
///
/// \return The set type of model outputs.
enum DataType GetOutputType() const;
/// \brief Set precision mode of model.
///
/// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
/// "allow_mix_precision", "force_fp16" is set as default
inline void SetPrecisionMode(const std::string &precision_mode);
/// \brief Get precision mode of model.
///
/// \return The set type of model outputs
inline std::string GetPrecisionMode() const;
/// \brief Set op select implementation mode.
///
/// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
/// default.
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
/// \brief Get op select implementation mode.
///
/// \return The set op select implementation mode.
inline std::string GetOpSelectImplMode() const;
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;
private:
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;
void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;
void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;
std::vector<char> GetDynamicBatchSizeChar() const;
void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
std::vector<char> GetDynamicImageSizeChar() const;
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
};
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;
void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
inline void SetDumpConfigPath(const std::string &cfg_path);
inline std::string GetDumpConfigPath() const;
// aipp config file
inline void SetInsertOpConfigPath(const std::string &cfg_path);
inline std::string GetInsertOpConfigPath() const;
// nchw or nhwc
inline void SetInputFormat(const std::string &format);
inline std::string GetInputFormat() const;
// Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
inline void SetInputShape(const std::string &shape);
inline std::string GetInputShape() const;
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
std::map<int, std::vector<int>> GetInputShapeMap() const;
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;
// FP32, UINT8 or FP16, default as FP32
void SetOutputType(enum DataType output_type);
enum DataType GetOutputType() const;
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", default as "force_fp16"
inline void SetPrecisionMode(const std::string &precision_mode);
inline std::string GetPrecisionMode() const;
// Optional "high_performance" and "high_precision", "high_performance" is set as default
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
inline std::string GetOpSelectImplMode() const;
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;
private:
void SetDumpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetDumpConfigPathChar() const;
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;
void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;
void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;
std::vector<char> GetDynamicBatchSizeChar() const;
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
};
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path));
void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
SetDynamicImageSize(StringToChar(dynamic_image_size));
}
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
}
void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

View File

@ -317,13 +317,7 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
return ret;
}
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
MS_LOG(ERROR) << "Unsupported Feature.";
return 0;
}
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -331,7 +325,7 @@ void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
data_->params[kModelOptionAscend310DeviceID] = device_id;
}
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
uint32_t AscendDeviceInfo::GetDeviceID() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
@ -339,14 +333,14 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
}
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -355,7 +349,7 @@ std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -363,7 +357,7 @@ void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
}
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -372,14 +366,14 @@ std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
}
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -388,7 +382,7 @@ std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -403,7 +397,7 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
}
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -412,7 +406,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -420,7 +414,7 @@ void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_i
data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size);
}
std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -429,7 +423,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -437,7 +431,7 @@ void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mo
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
}
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -446,7 +440,7 @@ std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -454,7 +448,7 @@ void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
}
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -463,14 +457,14 @@ std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
@ -479,7 +473,7 @@ std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -487,7 +481,7 @@ void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>>
data_->params[kModelOptionAscend310InputShapeMap] = shape;
}
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::map<int, std::vector<int>>();
@ -495,7 +489,7 @@ std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
}
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -503,7 +497,7 @@ void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
data_->params[kModelOptionAscend310OutputType] = output_type;
}
enum DataType Ascend310DeviceInfo::GetOutputType() const {
enum DataType AscendDeviceInfo::GetOutputType() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return DataType::kTypeUnknown;
@ -511,7 +505,7 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const {
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
}
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
@ -519,7 +513,7 @@ void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
}
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();

View File

@ -104,7 +104,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) {
} else if (device->GetDeviceType() == kKirinNPU) {
auto npu_context = device->Cast<KirinNPUDeviceInfo>();
ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
} else if (device->GetDeviceType() == kAscend310) {
} else if (device->GetDeviceType() == kAscend) {
ret = AddAscend310Device(inner_context.get(), device.get());
}
if (ret != kSuccess) {

View File

@ -71,11 +71,11 @@ std::shared_ptr<mindspore::Context> Common::ContextAutoSet() {
auto context = std::make_shared<mindspore::Context>();
if (device_target_str == "Ascend310") {
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
ascend310_info->SetDeviceID(device_id);
context->MutableDeviceInfo().emplace_back(ascend310_info);
} else if (device_target_str == "Ascend910") {
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
ascend310_info->SetDeviceID(device_id);
context->MutableDeviceInfo().emplace_back(ascend310_info);
} else {

View File

@ -101,7 +101,7 @@ TEST_F(TestDE, TestDvpp) {
auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr);
auto device_id = ascend310_info->GetDeviceID();
@ -154,7 +154,7 @@ TEST_F(TestDE, TestDvppSinkMode) {
auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr);
auto device_id = ascend310_info->GetDeviceID();
@ -202,7 +202,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr);
auto device_id = ascend310_info->GetDeviceID();

View File

@ -38,7 +38,7 @@ TEST_F(TestDynamicBatchSize, InferMindIR) {
auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr);
std::map<int, std::vector<int>> input_shape;

View File

@ -59,7 +59,7 @@ TEST_F(TestZeroCopy, TestMindIR) {
auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr);
ascend310_info->SetInsertOpConfigPath(aipp_path);
auto device_id = ascend310_info->GetDeviceID();
@ -107,7 +107,7 @@ TEST_F(TestZeroCopy, TestDeviceTensor) {
auto context = ContextAutoSet();
ASSERT_TRUE(context != nullptr);
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ascend310_info != nullptr);
ascend310_info->SetInsertOpConfigPath(aipp_path);
auto device_id = ascend310_info->GetDeviceID();

View File

@ -27,32 +27,27 @@ TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) {
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();
ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr);
ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr);
ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr);
ASSERT_TRUE(ascend310->Cast<Ascend310DeviceInfo>() != nullptr);
ASSERT_TRUE(ascend910->Cast<Ascend910DeviceInfo>() != nullptr);
ASSERT_TRUE(ascend->Cast<AscendDeviceInfo>() != nullptr);
}
TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) {
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();
ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend310->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend910->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend->Cast<GPUDeviceInfo>() == nullptr);
ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend310->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend910->Cast<CPUDeviceInfo>() == nullptr);
ASSERT_TRUE(ascend->Cast<CPUDeviceInfo>() == nullptr);
}
TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
@ -86,7 +81,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
std::string option_9_ans = "1,2,3,4,5";
auto context = std::make_shared<Context>();
std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>();
std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>();
ascend310->SetInputShape(option_1);
ascend310->SetInsertOpConfigPath(option_2);
ascend310->SetOpSelectImplMode(option_3);
@ -99,7 +94,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
context->MutableDeviceInfo().push_back(ascend310);
ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
auto ctx = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
auto ctx = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
ASSERT_TRUE(ctx != nullptr);
ASSERT_EQ(ascend310->GetInputShape(), option_1);
ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2);
@ -113,7 +108,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
}
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
auto ctx = std::make_shared<Ascend310DeviceInfo>();
auto ctx = std::make_shared<AscendDeviceInfo>();
ASSERT_EQ(ctx->GetOpSelectImplMode(), "");
}
} // namespace mindspore