forked from mindspore-Ecosystem/mindspore
C++ inference interface Ascend310DeviceInfo+Ascend910DeviceInfo->AscendDeviceInfo
This commit is contained in:
parent
87674cf3bd
commit
0a0419498f
|
@ -28,6 +28,7 @@ enum DeviceType {
|
||||||
kCPU = 0,
|
kCPU = 0,
|
||||||
kGPU,
|
kGPU,
|
||||||
kKirinNPU,
|
kKirinNPU,
|
||||||
|
kAscend,
|
||||||
kAscend910,
|
kAscend910,
|
||||||
kAscend310,
|
kAscend310,
|
||||||
// add new type here
|
// add new type here
|
||||||
|
@ -287,34 +288,14 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
}
|
}
|
||||||
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
|
|
||||||
/// invalid for MindSpore Lite.
|
|
||||||
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
|
||||||
public:
|
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
|
||||||
|
|
||||||
/// \brief Set device id.
|
|
||||||
///
|
|
||||||
/// \param[in] device_id The device id.
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
|
||||||
|
|
||||||
/// \brief Get the device id.
|
|
||||||
///
|
|
||||||
/// \return The device id.
|
|
||||||
uint32_t GetDeviceID() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
|
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
|
||||||
/// invalid for MindSpore Lite.
|
/// invalid for MindSpore Lite.
|
||||||
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
class MS_API AscendDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
/// \brief Get the type of this DeviceInfoContext.
|
||||||
///
|
///
|
||||||
/// \return Type of this DeviceInfoContext.
|
/// \return Type of this DeviceInfoContext.
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
|
||||||
|
|
||||||
/// \brief Set device id.
|
/// \brief Set device id.
|
||||||
///
|
///
|
||||||
|
@ -447,45 +428,48 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
||||||
std::vector<char> GetBufferOptimizeModeChar() const;
|
std::vector<char> GetBufferOptimizeModeChar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
using Ascend310DeviceInfo = AscendDeviceInfo;
|
||||||
|
using Ascend910DeviceInfo = AscendDeviceInfo;
|
||||||
|
|
||||||
|
void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
||||||
SetInsertOpConfigPath(StringToChar(cfg_path));
|
SetInsertOpConfigPath(StringToChar(cfg_path));
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
|
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
|
void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
|
||||||
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
|
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
|
void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
|
||||||
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
|
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
|
||||||
|
|
||||||
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
|
std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
|
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
|
||||||
SetDynamicImageSize(StringToChar(dynamic_image_size));
|
SetDynamicImageSize(StringToChar(dynamic_image_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Ascend310DeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
|
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
SetPrecisionMode(StringToChar(precision_mode));
|
SetPrecisionMode(StringToChar(precision_mode));
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
|
void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
|
||||||
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
|
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
|
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
|
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
|
||||||
SetFusionSwitchConfigPath(StringToChar(cfg_path));
|
SetFusionSwitchConfigPath(StringToChar(cfg_path));
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
|
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
|
||||||
return CharToString(GetFusionSwitchConfigPathChar());
|
return CharToString(GetFusionSwitchConfigPathChar());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
|
void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
|
||||||
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
|
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
|
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||||
|
|
|
@ -193,7 +193,7 @@ class MS_API Model {
|
||||||
|
|
||||||
/// \brief Inference model.
|
/// \brief Inference model.
|
||||||
///
|
///
|
||||||
/// \param[in] device_type Device type,options are kGPU, kAscend910, etc.
|
/// \param[in] device_type Device type,options are kGPU, kAscend, kAscend910, etc.
|
||||||
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
|
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
|
||||||
///
|
///
|
||||||
/// \return Is supported or not.
|
/// \return Is supported or not.
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
|
||||||
|
#define MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "acl/acl_base.h"
|
||||||
|
namespace mindspore {
|
||||||
|
static inline bool IsAscend910Soc() {
|
||||||
|
const char *soc_name_c = aclrtGetSocName();
|
||||||
|
if (soc_name_c == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string soc_name(soc_name_c);
|
||||||
|
if (soc_name.find("910") == std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool IsAscendNo910Soc() {
|
||||||
|
const char *soc_name_c = aclrtGetSocName();
|
||||||
|
if (soc_name_c == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string soc_name(soc_name_c);
|
||||||
|
if (soc_name.find("910") != std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_CXX_API_ACL_UTILS_H
|
|
@ -175,55 +175,46 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
|
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
|
||||||
data_->params[kModelOptionAscend910DeviceID] = device_id;
|
|
||||||
}
|
|
||||||
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
|
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
|
||||||
return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
|
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310DeviceID] = device_id;
|
data_->params[kModelOptionAscend310DeviceID] = device_id;
|
||||||
}
|
}
|
||||||
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
uint32_t AscendDeviceInfo::GetDeviceID() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
|
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
|
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
|
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
|
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
|
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
|
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
std::string batchs = "";
|
std::string batchs = "";
|
||||||
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
|
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
|
||||||
|
@ -234,69 +225,69 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
|
||||||
}
|
}
|
||||||
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
|
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
|
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }
|
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) { return; }
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
|
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
|
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
|
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
|
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
|
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
|
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
|
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
|
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310InputShapeMap] = shape;
|
data_->params[kModelOptionAscend310InputShapeMap] = shape;
|
||||||
}
|
}
|
||||||
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
|
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
|
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310OutputType] = output_type;
|
data_->params[kModelOptionAscend310OutputType] = output_type;
|
||||||
}
|
}
|
||||||
enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
enum DataType AscendDeviceInfo::GetOutputType() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
|
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
|
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
|
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
|
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
|
|
|
@ -24,7 +24,31 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
inline std::string g_device_target = "Default";
|
inline enum DeviceType g_device_target = kInvalidDeviceType;
|
||||||
|
|
||||||
|
static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
|
||||||
|
switch (device_type) {
|
||||||
|
case kAscend:
|
||||||
|
stream << "Ascend";
|
||||||
|
break;
|
||||||
|
case kAscend910:
|
||||||
|
stream << "Ascend910";
|
||||||
|
break;
|
||||||
|
case kAscend310:
|
||||||
|
stream << "Ascend310";
|
||||||
|
break;
|
||||||
|
case kGPU:
|
||||||
|
stream << "GPU";
|
||||||
|
break;
|
||||||
|
case kCPU:
|
||||||
|
stream << "CPU";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
stream << "[InvalidDeviceType: " << static_cast<int>(device_type) << "]";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
class Factory {
|
class Factory {
|
||||||
|
@ -39,32 +63,24 @@ class Factory {
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Register(const std::string &device_name, U &&creator) {
|
void Register(U &&creator) { creators_.push_back(creator); }
|
||||||
if (creators_.find(device_name) == creators_.end()) {
|
|
||||||
(void)creators_.emplace(device_name, creator);
|
std::shared_ptr<T> Create(enum DeviceType device_type) {
|
||||||
|
for (auto &item : creators_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
auto val = item();
|
||||||
|
if (val->CheckDeviceSupport(device_type)) {
|
||||||
|
return val;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
MS_LOG(WARNING) << "Unsupported device target " << device_type;
|
||||||
|
|
||||||
bool CheckModelSupport(const std::string &device_name) {
|
|
||||||
return std::any_of(creators_.begin(), creators_.end(),
|
|
||||||
[&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; });
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<T> Create(const std::string &device_name) {
|
|
||||||
auto iter = creators_.find(device_name);
|
|
||||||
if (creators_.end() != iter) {
|
|
||||||
MS_EXCEPTION_IF_NULL(iter->second);
|
|
||||||
return (iter->second)();
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(ERROR) << "Unsupported device target " << device_name;
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Factory() = default;
|
Factory() = default;
|
||||||
~Factory() = default;
|
~Factory() = default;
|
||||||
std::map<std::string, U> creators_;
|
std::vector<U> creators_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
|
@ -72,14 +88,12 @@ class Registrar {
|
||||||
using U = std::function<std::shared_ptr<T>()>;
|
using U = std::function<std::shared_ptr<T>()>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Registrar(const std::string &device_name, U creator) {
|
explicit Registrar(U creator) { Factory<T>::Instance().Register(std::move(creator)); }
|
||||||
Factory<T>::Instance().Register(device_name, std::move(creator));
|
|
||||||
}
|
|
||||||
~Registrar() = default;
|
~Registrar() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
|
#define API_FACTORY_REG(BASE_CLASS, DERIVE_CLASS) \
|
||||||
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
|
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_reg( \
|
||||||
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
|
[]() { return std::make_shared<DERIVE_CLASS>(); });
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
||||||
|
|
|
@ -18,9 +18,10 @@
|
||||||
#include "cxx_api/model/acl/model_converter.h"
|
#include "cxx_api/model/acl/model_converter.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "mindspore/core/utils/convert_utils_base.h"
|
#include "mindspore/core/utils/convert_utils_base.h"
|
||||||
|
#include "cxx_api/acl_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
API_FACTORY_REG(GraphCell::GraphImpl, Ascend310, AclGraphImpl);
|
API_FACTORY_REG(GraphCell::GraphImpl, AclGraphImpl);
|
||||||
|
|
||||||
AclGraphImpl::AclGraphImpl()
|
AclGraphImpl::AclGraphImpl()
|
||||||
: init_flag_(false),
|
: init_flag_(false),
|
||||||
|
@ -231,4 +232,12 @@ Status AclGraphImpl::ConvertToOM() {
|
||||||
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
|
MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType();
|
||||||
return kMCFailed;
|
return kMCFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
|
||||||
|
// for Ascend, only support kAscend and kAscend310
|
||||||
|
if (device_type != kAscend && device_type != kAscend310) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return IsAscendNo910Soc();
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -37,6 +37,7 @@ class AclGraphImpl : public GraphCell::GraphImpl {
|
||||||
Status Load(uint32_t device_id) override;
|
Status Load(uint32_t device_id) override;
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status ConvertToOM();
|
Status ConvertToOM();
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
#include "cxx_api/factory.h"
|
#include "cxx_api/factory.h"
|
||||||
#include "cxx_api/akg_kernel_register.h"
|
#include "cxx_api/akg_kernel_register.h"
|
||||||
|
#include "cxx_api/acl_utils.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "utils/context/context_extends.h"
|
#include "utils/context/context_extends.h"
|
||||||
#include "mindspore/core/base/base_ref_utils.h"
|
#include "mindspore/core/base/base_ref_utils.h"
|
||||||
|
@ -30,7 +31,7 @@
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
|
API_FACTORY_REG(GraphCell::GraphImpl, AscendGraphImpl);
|
||||||
|
|
||||||
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
|
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
|
||||||
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
|
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
|
||||||
|
@ -382,6 +383,14 @@ std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv
|
||||||
return acl_env;
|
return acl_env;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) {
|
||||||
|
// for Ascend, only support kAscend and kAscend910
|
||||||
|
if (device_type != kAscend && device_type != kAscend910) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return IsAscend910Soc();
|
||||||
|
}
|
||||||
|
|
||||||
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
|
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
|
||||||
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
|
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl {
|
||||||
Status Load(uint32_t device_id) override;
|
Status Load(uint32_t device_id) override;
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class MsEnvGuard;
|
class MsEnvGuard;
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include "runtime/device/gpu/cuda_driver.h"
|
#include "runtime/device/gpu/cuda_driver.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
|
API_FACTORY_REG(GraphCell::GraphImpl, GPUGraphImpl);
|
||||||
|
|
||||||
GPUGraphImpl::GPUGraphImpl()
|
GPUGraphImpl::GPUGraphImpl()
|
||||||
: session_impl_(nullptr),
|
: session_impl_(nullptr),
|
||||||
|
@ -291,4 +291,6 @@ std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; }
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -37,6 +37,8 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status InitEnv();
|
Status InitEnv();
|
||||||
Status FinalizeEnv();
|
Status FinalizeEnv();
|
||||||
|
|
|
@ -42,6 +42,8 @@ class GraphCell::GraphImpl {
|
||||||
virtual std::vector<MSTensor> GetInputs() = 0;
|
virtual std::vector<MSTensor> GetInputs() = 0;
|
||||||
virtual std::vector<MSTensor> GetOutputs() = 0;
|
virtual std::vector<MSTensor> GetOutputs() = 0;
|
||||||
|
|
||||||
|
virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<Graph> graph_;
|
std::shared_ptr<Graph> graph_;
|
||||||
std::shared_ptr<Context> graph_context_;
|
std::shared_ptr<Context> graph_context_;
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
#include "cxx_api/factory.h"
|
#include "cxx_api/factory.h"
|
||||||
#include "cxx_api/graph/acl/acl_env_guard.h"
|
#include "cxx_api/graph/acl/acl_env_guard.h"
|
||||||
#include "acl/acl_base.h"
|
#include "cxx_api/acl_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
Status AclModel::Build() {
|
Status AclModel::Build() {
|
||||||
|
@ -112,7 +112,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
|
||||||
|
|
||||||
if (model_context_ == nullptr) {
|
if (model_context_ == nullptr) {
|
||||||
model_context_ = std::make_shared<Context>();
|
model_context_ = std::make_shared<Context>();
|
||||||
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>());
|
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<AscendDeviceInfo>());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string input_shape_option;
|
std::string input_shape_option;
|
||||||
|
@ -139,7 +139,7 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
|
||||||
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
|
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
|
||||||
return kMCInvalidArgs;
|
return kMCInvalidArgs;
|
||||||
}
|
}
|
||||||
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
|
||||||
MS_EXCEPTION_IF_NULL(ascend310_info);
|
MS_EXCEPTION_IF_NULL(ascend310_info);
|
||||||
ascend310_info->SetInputShape(input_shape_option);
|
ascend310_info->SetInputShape(input_shape_option);
|
||||||
auto graph_cell_bak = std::move(graph_cell_);
|
auto graph_cell_bak = std::move(graph_cell_);
|
||||||
|
@ -163,16 +163,15 @@ std::vector<MSTensor> AclModel::GetOutputs() {
|
||||||
return graph_cell_->GetOutputs();
|
return graph_cell_->GetOutputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AclModel::CheckModelSupport(enum ModelType model_type) {
|
bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) {
|
||||||
const char *soc_name_c = aclrtGetSocName();
|
// for Ascend, only support kAscend and kAscend310
|
||||||
if (soc_name_c == nullptr) {
|
if (device_type != kAscend && device_type != kAscend310) {
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::string soc_name(soc_name_c);
|
|
||||||
if (soc_name.find("910") != std::string::npos) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
return IsAscendNo910Soc();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AclModel::CheckModelSupport(enum ModelType model_type) {
|
||||||
static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM};
|
static const std::set<ModelType> kSupportedModelMap = {kMindIR, kOM};
|
||||||
auto iter = kSupportedModelMap.find(model_type);
|
auto iter = kSupportedModelMap.find(model_type);
|
||||||
if (iter == kSupportedModelMap.end()) {
|
if (iter == kSupportedModelMap.end()) {
|
||||||
|
|
|
@ -43,6 +43,7 @@ class AclModel : public ModelImpl {
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
|
||||||
bool CheckModelSupport(enum ModelType model_type) override;
|
bool CheckModelSupport(enum ModelType model_type) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
|
#include "cxx_api/model/acl/acl_vm/acl_vm.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
|
API_FACTORY_REG(ModelImpl, AclModelMulti);
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
std::map<DataType, size_t> kDtypeMap = {
|
std::map<DataType, size_t> kDtypeMap = {
|
||||||
|
|
|
@ -33,7 +33,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
|
||||||
if (device_infos.size() != 1) {
|
if (device_infos.size() != 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = device_infos[0]->Cast<AscendDeviceInfo>();
|
||||||
if (ascend310_info == nullptr) {
|
if (ascend310_info == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,19 +20,6 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace {
|
|
||||||
std::string GetDeviceTypeString(enum DeviceType type) {
|
|
||||||
static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
|
|
||||||
{kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
|
|
||||||
};
|
|
||||||
auto iter = kDeviceTypeStrs.find(type);
|
|
||||||
if (iter != kDeviceTypeStrs.end()) {
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
|
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
|
||||||
const std::shared_ptr<TrainCfg> &) {
|
const std::shared_ptr<TrainCfg> &) {
|
||||||
if (graph_cell.GetGraph() == nullptr) {
|
if (graph_cell.GetGraph() == nullptr) {
|
||||||
|
@ -50,7 +37,7 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
|
||||||
return kMCInvalidInput;
|
return kMCInvalidInput;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
|
auto device_target = device_info[0]->GetDeviceType();
|
||||||
impl_ = Factory<ModelImpl>::Instance().Create(device_target);
|
impl_ = Factory<ModelImpl>::Instance().Create(device_target);
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create session type " << device_target << " failed";
|
MS_LOG(ERROR) << "Create session type " << device_target << " failed";
|
||||||
|
@ -175,16 +162,10 @@ Model::Model() : impl_(nullptr) {}
|
||||||
Model::~Model() {}
|
Model::~Model() {}
|
||||||
|
|
||||||
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
|
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
|
||||||
std::string device_type_str = GetDeviceTypeString(device_type);
|
auto check_model = Factory<ModelImpl>::Instance().Create(device_type);
|
||||||
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
|
|
||||||
if (check_model == nullptr) {
|
if (check_model == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return check_model->CheckModelSupport(model_type);
|
return check_model->CheckModelSupport(model_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,8 @@ class ModelImpl {
|
||||||
virtual std::vector<MSTensor> GetInputs() = 0;
|
virtual std::vector<MSTensor> GetInputs() = 0;
|
||||||
virtual std::vector<MSTensor> GetOutputs() = 0;
|
virtual std::vector<MSTensor> GetOutputs() = 0;
|
||||||
|
|
||||||
virtual bool CheckModelSupport(enum ModelType model_type) { return false; }
|
virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0;
|
||||||
|
virtual bool CheckModelSupport(enum ModelType model_type) = 0;
|
||||||
|
|
||||||
virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
virtual Status Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
||||||
|
|
||||||
|
|
|
@ -20,14 +20,13 @@
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "cxx_api/factory.h"
|
#include "cxx_api/factory.h"
|
||||||
|
#if ENABLE_D
|
||||||
|
#include "cxx_api/acl_utils.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
// mindspore-serving check current package for version check with ModelImpl factory.
|
// mindspore-serving check current package for version check with ModelImpl factory.
|
||||||
#if ENABLE_D
|
API_FACTORY_REG(ModelImpl, MsModel);
|
||||||
API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
|
|
||||||
#elif ENABLE_GPU
|
|
||||||
API_FACTORY_REG(ModelImpl, GPU, MsModel);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
|
static std::string GenerateShapeKey(const std::vector<std::vector<int64_t>> &dims) {
|
||||||
std::string shape_key;
|
std::string shape_key;
|
||||||
|
@ -171,18 +170,23 @@ uint32_t MsModel::GetDeviceID() const {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MsModel::CheckModelSupport(enum ModelType model_type) {
|
bool MsModel::CheckDeviceSupport(enum DeviceType device_type) {
|
||||||
#if ENABLE_D
|
#if ENABLE_D
|
||||||
const char *soc_name_c = aclrtGetSocName();
|
// for Ascend, only support kAscend or kAscend910
|
||||||
if (soc_name_c == nullptr) {
|
if (device_type != kAscend && device_type != kAscend910) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::string soc_name(soc_name_c);
|
return IsAscend910Soc();
|
||||||
if (soc_name.find("910") == std::string::npos) {
|
#else
|
||||||
|
// otherwise, only support GPU
|
||||||
|
if (device_type != kGPU) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MsModel::CheckModelSupport(mindspore::ModelType model_type) {
|
||||||
static const std::set<ModelType> kSupportedModelMap = {kMindIR};
|
static const std::set<ModelType> kSupportedModelMap = {kMindIR};
|
||||||
auto iter = kSupportedModelMap.find(model_type);
|
auto iter = kSupportedModelMap.find(model_type);
|
||||||
if (iter == kSupportedModelMap.end()) {
|
if (iter == kSupportedModelMap.end()) {
|
||||||
|
|
|
@ -44,6 +44,7 @@ class MsModel : public ModelImpl {
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
bool CheckDeviceSupport(mindspore::DeviceType device_type) override;
|
||||||
bool CheckModelSupport(enum ModelType model_type) override;
|
bool CheckModelSupport(enum ModelType model_type) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
*
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
|
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||||
#define MINDSPORE_INCLUDE_API_CONTEXT_H
|
#define MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||||
|
|
||||||
|
@ -25,212 +25,431 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
enum DeviceType {
|
enum DeviceType {
|
||||||
kCPU = 0,
|
kCPU = 0,
|
||||||
kGPU,
|
kGPU,
|
||||||
kKirinNPU,
|
kKirinNPU,
|
||||||
kAscend910,
|
kAscend,
|
||||||
kAscend310,
|
kAscend910,
|
||||||
// add new type here
|
kAscend310,
|
||||||
kInvalidDeviceType = 100,
|
// add new type here
|
||||||
|
kInvalidDeviceType = 100,
|
||||||
};
|
};
|
||||||
|
|
||||||
class Allocator;
|
class Allocator;
|
||||||
|
class Delegate;
|
||||||
class DeviceInfoContext;
|
class DeviceInfoContext;
|
||||||
|
|
||||||
|
/// \brief Context is used to store environment variables during execution.
|
||||||
class MS_API Context {
|
class MS_API Context {
|
||||||
public:
|
public:
|
||||||
Context();
|
struct Data;
|
||||||
~Context() = default;
|
Context();
|
||||||
|
~Context() = default;
|
||||||
|
|
||||||
void SetThreadNum(int32_t thread_num);
|
/// \brief Set the number of threads at runtime. Only valid for Lite.
|
||||||
int32_t GetThreadNum() const;
|
///
|
||||||
|
/// \param[in] thread_num the number of threads at runtime.
|
||||||
|
void SetThreadNum(int32_t thread_num);
|
||||||
|
|
||||||
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
/// \brief Get the current thread number setting. Only valid for Lite.
|
||||||
std::shared_ptr<Allocator> GetAllocator() const;
|
///
|
||||||
|
/// \return The current thread number setting.
|
||||||
|
int32_t GetThreadNum() const;
|
||||||
|
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
/// \brief Set the thread affinity to CPU cores. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
|
||||||
|
void SetThreadAffinity(int mode);
|
||||||
|
|
||||||
private:
|
/// \brief Get the thread affinity of CPU cores. Only valid for Lite.
|
||||||
struct Data;
|
///
|
||||||
std::shared_ptr<Data> data_;
|
/// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
|
||||||
|
int GetThreadAffinityMode() const;
|
||||||
|
|
||||||
|
/// \brief Set the thread lists to CPU cores. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the
|
||||||
|
/// mode is not effective.
|
||||||
|
///
|
||||||
|
/// \param[in] core_list: a vector of thread core lists.
|
||||||
|
void SetThreadAffinity(const std::vector<int> &core_list);
|
||||||
|
|
||||||
|
/// \brief Get the thread lists of CPU cores. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \return core_list: a vector of thread core lists.
|
||||||
|
std::vector<int32_t> GetThreadAffinityCoreList() const;
|
||||||
|
|
||||||
|
/// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \param[in] is_parallel: true, parallel; false, not in parallel.
|
||||||
|
void SetEnableParallel(bool is_parallel);
|
||||||
|
|
||||||
|
/// \brief Get the status whether to perform model inference or training in parallel. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \return Bool value that indicates whether in parallel.
|
||||||
|
bool GetEnableParallel() const;
|
||||||
|
|
||||||
|
/// \brief Set Delegate to access third-party AI framework. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \param[in] Pointer to the custom delegate.
|
||||||
|
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
|
||||||
|
|
||||||
|
/// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \return Pointer to the custom delegate.
|
||||||
|
std::shared_ptr<Delegate> GetDelegate() const;
|
||||||
|
|
||||||
|
/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
|
||||||
|
/// heterogeneous scenarios with multiple members in the vector.
|
||||||
|
///
|
||||||
|
/// \return Mutable reference of DeviceInfoContext vector in this context.
|
||||||
|
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief DeviceInfoContext defines different device contexts.
|
||||||
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
|
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
|
||||||
public:
|
public:
|
||||||
struct Data;
|
struct Data;
|
||||||
|
|
||||||
DeviceInfoContext();
|
DeviceInfoContext();
|
||||||
virtual ~DeviceInfoContext() = default;
|
virtual ~DeviceInfoContext() = default;
|
||||||
virtual enum DeviceType GetDeviceType() const = 0;
|
|
||||||
|
|
||||||
template <class T>
|
/// \brief Get the type of this DeviceInfoContext.
|
||||||
std::shared_ptr<T> Cast() {
|
///
|
||||||
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
|
/// \return Type of this DeviceInfoContext.
|
||||||
if (GetDeviceType() != T().GetDeviceType()) {
|
virtual enum DeviceType GetDeviceType() const = 0;
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::static_pointer_cast<T>(shared_from_this());
|
/// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
|
||||||
}
|
/// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
|
||||||
|
///
|
||||||
|
/// \param T Type
|
||||||
|
/// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
|
||||||
|
template <class T>
|
||||||
|
std::shared_ptr<T> Cast() {
|
||||||
|
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
|
||||||
|
if (GetDeviceType() != T().GetDeviceType()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
return std::static_pointer_cast<T>(shared_from_this());
|
||||||
std::shared_ptr<Data> data_;
|
}
|
||||||
|
/// \brief obtain provider's name
|
||||||
|
///
|
||||||
|
/// \return provider's name.
|
||||||
|
std::string GetProvider() const;
|
||||||
|
/// \brief set provider's name.
|
||||||
|
///
|
||||||
|
/// \param[in] provider define the provider's name.
|
||||||
|
|
||||||
|
void SetProvider(const std::string &provider);
|
||||||
|
/// \brief obtain provider's device type.
|
||||||
|
///
|
||||||
|
/// \return provider's device type.
|
||||||
|
|
||||||
|
std::string GetProviderDevice() const;
|
||||||
|
/// \brief set provider's device type.
|
||||||
|
///
|
||||||
|
/// \param[in] device define the provider's device type.EG: CPU.
|
||||||
|
void SetProviderDevice(const std::string &device);
|
||||||
|
|
||||||
|
/// \brief set memory allocator.
|
||||||
|
///
|
||||||
|
/// \param[in] allocator define the memory allocator which can be defined by user.
|
||||||
|
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
||||||
|
|
||||||
|
/// \brief obtain memory allocator.
|
||||||
|
///
|
||||||
|
/// \return memory allocator.
|
||||||
|
std::shared_ptr<Allocator> GetAllocator() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
|
||||||
|
/// for MindSpore Lite.
|
||||||
class MS_API CPUDeviceInfo : public DeviceInfoContext {
|
class MS_API CPUDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
|
/// \brief Get the type of this DeviceInfoContext.
|
||||||
|
///
|
||||||
|
/// \return Type of this DeviceInfoContext.
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
|
||||||
|
|
||||||
/// \brief Set the thread affinity to CPU cores.
|
/// \brief Set enables to perform the float16 inference
|
||||||
///
|
///
|
||||||
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
|
/// \param[in] is_fp16 Enable float16 inference or not.
|
||||||
void SetThreadAffinity(int mode);
|
void SetEnableFP16(bool is_fp16);
|
||||||
int GetThreadAffinity() const;
|
|
||||||
void SetEnableFP16(bool is_fp16);
|
/// \brief Get enables to perform the float16 inference
|
||||||
bool GetEnableFP16() const;
|
///
|
||||||
|
/// \return Whether enable float16 inference.
|
||||||
|
bool GetEnableFP16() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
|
||||||
|
/// for MindSpore Lite.
|
||||||
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
|
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
|
/// \brief Get the type of this DeviceInfoContext.
|
||||||
|
///
|
||||||
|
/// \return Type of this DeviceInfoContext.
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
|
||||||
|
|
||||||
void SetFrequency(int frequency);
|
/// \brief Set the NPU frequency.
|
||||||
int GetFrequency() const;
|
///
|
||||||
|
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
|
||||||
|
/// performance), default as 3.
|
||||||
|
void SetFrequency(int frequency);
|
||||||
|
|
||||||
|
/// \brief Get the NPU frequency.
|
||||||
|
///
|
||||||
|
/// \return NPU frequency
|
||||||
|
int GetFrequency() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
|
||||||
class MS_API GPUDeviceInfo : public DeviceInfoContext {
|
class MS_API GPUDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
|
/// \brief Get the type of this DeviceInfoContext.
|
||||||
|
///
|
||||||
|
/// \return Type of this DeviceInfoContext.
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
|
||||||
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
/// \brief Set device id.
|
||||||
uint32_t GetDeviceID() const;
|
///
|
||||||
|
/// \param[in] device_id The device id.
|
||||||
|
void SetDeviceID(uint32_t device_id);
|
||||||
|
|
||||||
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
|
/// \brief Get the device id.
|
||||||
bool GetGpuTrtInferMode() const;
|
///
|
||||||
|
/// \return The device id.
|
||||||
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
void SetEnableFP16(bool is_fp16);
|
/// \brief Get the distribution rank id.
|
||||||
bool GetEnableFP16() const;
|
///
|
||||||
|
/// \return The device id.
|
||||||
|
int GetRankID() const;
|
||||||
|
|
||||||
|
/// \brief Get the distribution group size.
|
||||||
|
///
|
||||||
|
/// \return The device id.
|
||||||
|
int GetGroupSize() const;
|
||||||
|
|
||||||
|
/// \brief Set the precision mode.
|
||||||
|
///
|
||||||
|
/// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
|
||||||
|
inline void SetPrecisionMode(const std::string &precision_mode);
|
||||||
|
|
||||||
|
/// \brief Get the precision mode.
|
||||||
|
///
|
||||||
|
/// \return The precision mode.
|
||||||
|
inline std::string GetPrecisionMode() const;
|
||||||
|
|
||||||
|
/// \brief Set enables to perform the float16 inference
|
||||||
|
///
|
||||||
|
/// \param[in] is_fp16 Enable float16 inference or not.
|
||||||
|
void SetEnableFP16(bool is_fp16);
|
||||||
|
|
||||||
|
/// \brief Get enables to perform the float16 inference
|
||||||
|
///
|
||||||
|
/// \return Whether enable float16 inference.
|
||||||
|
bool GetEnableFP16() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
||||||
|
std::vector<char> GetPrecisionModeChar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
public:
|
SetPrecisionMode(StringToChar(precision_mode));
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
}
|
||||||
|
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
||||||
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
|
||||||
uint32_t GetDeviceID() const;
|
/// invalid for MindSpore Lite.
|
||||||
|
class MS_API AscendDeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
/// \brief Get the type of this DeviceInfoContext.
|
||||||
|
///
|
||||||
|
/// \return Type of this DeviceInfoContext.
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
|
||||||
|
|
||||||
|
/// \brief Set device id.
|
||||||
|
///
|
||||||
|
/// \param[in] device_id The device id.
|
||||||
|
void SetDeviceID(uint32_t device_id);
|
||||||
|
|
||||||
|
/// \brief Get the device id.
|
||||||
|
///
|
||||||
|
/// \return The device id.
|
||||||
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
|
/// \brief Set AIPP configuration file path.
|
||||||
|
///
|
||||||
|
/// \param[in] cfg_path AIPP configuration file path.
|
||||||
|
inline void SetInsertOpConfigPath(const std::string &cfg_path);
|
||||||
|
|
||||||
|
/// \brief Get AIPP configuration file path.
|
||||||
|
///
|
||||||
|
/// \return AIPP configuration file path.
|
||||||
|
inline std::string GetInsertOpConfigPath() const;
|
||||||
|
|
||||||
|
/// \brief Set format of model inputs.
|
||||||
|
///
|
||||||
|
/// \param[in] format Optional "NCHW", "NHWC", etc.
|
||||||
|
inline void SetInputFormat(const std::string &format);
|
||||||
|
|
||||||
|
/// \brief Get format of model inputs.
|
||||||
|
///
|
||||||
|
/// \return The format of model inputs.
|
||||||
|
inline std::string GetInputFormat() const;
|
||||||
|
|
||||||
|
/// \brief Set shape of model inputs.
|
||||||
|
///
|
||||||
|
/// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
|
||||||
|
inline void SetInputShape(const std::string &shape);
|
||||||
|
|
||||||
|
/// \brief Get shape of model inputs.
|
||||||
|
///
|
||||||
|
/// \return The shape of model inputs.
|
||||||
|
inline std::string GetInputShape() const;
|
||||||
|
|
||||||
|
/// \brief Set shape of model inputs.
|
||||||
|
///
|
||||||
|
/// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
|
||||||
|
/// shape 4,3,2,1.
|
||||||
|
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
|
||||||
|
|
||||||
|
/// \brief Get shape of model inputs.
|
||||||
|
///
|
||||||
|
/// \return The shape of model inputs.
|
||||||
|
std::map<int, std::vector<int>> GetInputShapeMap() const;
|
||||||
|
|
||||||
|
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
|
||||||
|
inline std::string GetDynamicBatchSize() const;
|
||||||
|
|
||||||
|
/// \brief Set the dynamic image size of model inputs.
|
||||||
|
///
|
||||||
|
/// \param[in] image size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
|
||||||
|
inline void SetDynamicImageSize(const std::string &dynamic_image_size);
|
||||||
|
|
||||||
|
/// \brief Get dynamic image size of model inputs.
|
||||||
|
///
|
||||||
|
/// \return The image size of model inputs.
|
||||||
|
inline std::string GetDynamicImageSize() const;
|
||||||
|
|
||||||
|
/// \brief Set type of model outputs.
|
||||||
|
///
|
||||||
|
/// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
|
||||||
|
void SetOutputType(enum DataType output_type);
|
||||||
|
|
||||||
|
/// \brief Get type of model outputs.
|
||||||
|
///
|
||||||
|
/// \return The set type of model outputs.
|
||||||
|
enum DataType GetOutputType() const;
|
||||||
|
|
||||||
|
/// \brief Set precision mode of model.
|
||||||
|
///
|
||||||
|
/// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
|
||||||
|
/// "allow_mix_precision", "force_fp16" is set as default
|
||||||
|
inline void SetPrecisionMode(const std::string &precision_mode);
|
||||||
|
|
||||||
|
/// \brief Get precision mode of model.
|
||||||
|
///
|
||||||
|
/// \return The set type of model outputs
|
||||||
|
inline std::string GetPrecisionMode() const;
|
||||||
|
|
||||||
|
/// \brief Set op select implementation mode.
|
||||||
|
///
|
||||||
|
/// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
|
||||||
|
/// default.
|
||||||
|
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
|
||||||
|
|
||||||
|
/// \brief Get op select implementation mode.
|
||||||
|
///
|
||||||
|
/// \return The set op select implementation mode.
|
||||||
|
inline std::string GetOpSelectImplMode() const;
|
||||||
|
|
||||||
|
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
|
||||||
|
inline std::string GetFusionSwitchConfigPath() const;
|
||||||
|
|
||||||
|
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
|
||||||
|
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
|
||||||
|
inline std::string GetBufferOptimizeMode() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
|
||||||
|
std::vector<char> GetInsertOpConfigPathChar() const;
|
||||||
|
|
||||||
|
void SetInputFormat(const std::vector<char> &format);
|
||||||
|
std::vector<char> GetInputFormatChar() const;
|
||||||
|
|
||||||
|
void SetInputShape(const std::vector<char> &shape);
|
||||||
|
std::vector<char> GetInputShapeChar() const;
|
||||||
|
|
||||||
|
std::vector<char> GetDynamicBatchSizeChar() const;
|
||||||
|
|
||||||
|
void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
|
||||||
|
std::vector<char> GetDynamicImageSizeChar() const;
|
||||||
|
|
||||||
|
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
||||||
|
std::vector<char> GetPrecisionModeChar() const;
|
||||||
|
|
||||||
|
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
|
||||||
|
std::vector<char> GetOpSelectImplModeChar() const;
|
||||||
|
|
||||||
|
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
|
||||||
|
std::vector<char> GetFusionSwitchConfigPathChar() const;
|
||||||
|
|
||||||
|
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
|
||||||
|
std::vector<char> GetBufferOptimizeModeChar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
using Ascend310DeviceInfo = AscendDeviceInfo;
|
||||||
public:
|
using Ascend910DeviceInfo = AscendDeviceInfo;
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
|
|
||||||
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
||||||
uint32_t GetDeviceID() const;
|
SetInsertOpConfigPath(StringToChar(cfg_path));
|
||||||
|
|
||||||
inline void SetDumpConfigPath(const std::string &cfg_path);
|
|
||||||
inline std::string GetDumpConfigPath() const;
|
|
||||||
|
|
||||||
// aipp config file
|
|
||||||
inline void SetInsertOpConfigPath(const std::string &cfg_path);
|
|
||||||
inline std::string GetInsertOpConfigPath() const;
|
|
||||||
|
|
||||||
// nchw or nhwc
|
|
||||||
inline void SetInputFormat(const std::string &format);
|
|
||||||
inline std::string GetInputFormat() const;
|
|
||||||
|
|
||||||
// Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
|
|
||||||
inline void SetInputShape(const std::string &shape);
|
|
||||||
inline std::string GetInputShape() const;
|
|
||||||
|
|
||||||
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
|
|
||||||
std::map<int, std::vector<int>> GetInputShapeMap() const;
|
|
||||||
|
|
||||||
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
|
|
||||||
inline std::string GetDynamicBatchSize() const;
|
|
||||||
|
|
||||||
// FP32, UINT8 or FP16, default as FP32
|
|
||||||
void SetOutputType(enum DataType output_type);
|
|
||||||
enum DataType GetOutputType() const;
|
|
||||||
|
|
||||||
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", default as "force_fp16"
|
|
||||||
inline void SetPrecisionMode(const std::string &precision_mode);
|
|
||||||
inline std::string GetPrecisionMode() const;
|
|
||||||
|
|
||||||
// Optional "high_performance" and "high_precision", "high_performance" is set as default
|
|
||||||
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
|
|
||||||
inline std::string GetOpSelectImplMode() const;
|
|
||||||
|
|
||||||
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
|
|
||||||
inline std::string GetFusionSwitchConfigPath() const;
|
|
||||||
|
|
||||||
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
|
|
||||||
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
|
|
||||||
inline std::string GetBufferOptimizeMode() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void SetDumpConfigPath(const std::vector<char> &cfg_path);
|
|
||||||
std::vector<char> GetDumpConfigPathChar() const;
|
|
||||||
|
|
||||||
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
|
|
||||||
std::vector<char> GetInsertOpConfigPathChar() const;
|
|
||||||
|
|
||||||
void SetInputFormat(const std::vector<char> &format);
|
|
||||||
std::vector<char> GetInputFormatChar() const;
|
|
||||||
|
|
||||||
void SetInputShape(const std::vector<char> &shape);
|
|
||||||
std::vector<char> GetInputShapeChar() const;
|
|
||||||
|
|
||||||
std::vector<char> GetDynamicBatchSizeChar() const;
|
|
||||||
|
|
||||||
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
|
||||||
std::vector<char> GetPrecisionModeChar() const;
|
|
||||||
|
|
||||||
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
|
|
||||||
std::vector<char> GetOpSelectImplModeChar() const;
|
|
||||||
|
|
||||||
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
|
|
||||||
std::vector<char> GetFusionSwitchConfigPathChar() const;
|
|
||||||
|
|
||||||
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
|
|
||||||
std::vector<char> GetBufferOptimizeModeChar() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
|
|
||||||
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
|
||||||
SetInsertOpConfigPath(StringToChar(cfg_path));
|
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
|
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
|
void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
|
||||||
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
|
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
|
void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
|
||||||
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
|
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
|
||||||
|
|
||||||
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
|
std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
|
||||||
SetPrecisionMode(StringToChar(precision_mode));
|
SetDynamicImageSize(StringToChar(dynamic_image_size));
|
||||||
}
|
|
||||||
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
|
|
||||||
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
|
|
||||||
}
|
|
||||||
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
|
|
||||||
SetFusionSwitchConfigPath(StringToChar(cfg_path));
|
|
||||||
}
|
|
||||||
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
|
|
||||||
return CharToString(GetFusionSwitchConfigPathChar());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
|
std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
|
||||||
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
|
|
||||||
|
void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
|
SetPrecisionMode(StringToChar(precision_mode));
|
||||||
}
|
}
|
||||||
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
|
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
||||||
|
|
||||||
|
void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
|
||||||
|
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
|
||||||
|
}
|
||||||
|
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
|
||||||
|
|
||||||
|
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
|
||||||
|
SetFusionSwitchConfigPath(StringToChar(cfg_path));
|
||||||
|
}
|
||||||
|
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
|
||||||
|
return CharToString(GetFusionSwitchConfigPathChar());
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
|
||||||
|
SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
|
||||||
|
}
|
||||||
|
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||||
|
|
|
@ -317,13 +317,7 @@ std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
|
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
|
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -331,7 +325,7 @@ void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
data_->params[kModelOptionAscend310DeviceID] = device_id;
|
data_->params[kModelOptionAscend310DeviceID] = device_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
uint32_t AscendDeviceInfo::GetDeviceID() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -339,14 +333,14 @@ uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
||||||
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
|
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -355,7 +349,7 @@ std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -363,7 +357,7 @@ void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
||||||
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
|
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
|
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -372,14 +366,14 @@ std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
|
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
|
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
|
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -388,7 +382,7 @@ std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -403,7 +397,7 @@ void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic
|
||||||
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
|
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
|
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -412,7 +406,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
|
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -420,7 +414,7 @@ void Ascend310DeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_i
|
||||||
data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size);
|
data_->params[kModelOptionAscend310DynamicImageSize] = CharToString(dynamic_image_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
|
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -429,7 +423,7 @@ std::vector<char> Ascend310DeviceInfo::GetDynamicImageSizeChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -437,7 +431,7 @@ void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mo
|
||||||
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
|
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
|
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -446,7 +440,7 @@ std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -454,7 +448,7 @@ void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select
|
||||||
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
|
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
|
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -463,14 +457,14 @@ std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
|
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
|
||||||
}
|
}
|
||||||
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
|
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
@ -479,7 +473,7 @@ std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -487,7 +481,7 @@ void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>>
|
||||||
data_->params[kModelOptionAscend310InputShapeMap] = shape;
|
data_->params[kModelOptionAscend310InputShapeMap] = shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::map<int, std::vector<int>>();
|
return std::map<int, std::vector<int>>();
|
||||||
|
@ -495,7 +489,7 @@ std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
||||||
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
|
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
|
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -503,7 +497,7 @@ void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
|
||||||
data_->params[kModelOptionAscend310OutputType] = output_type;
|
data_->params[kModelOptionAscend310OutputType] = output_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
enum DataType AscendDeviceInfo::GetOutputType() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return DataType::kTypeUnknown;
|
return DataType::kTypeUnknown;
|
||||||
|
@ -511,7 +505,7 @@ enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
||||||
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
|
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
|
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return;
|
return;
|
||||||
|
@ -519,7 +513,7 @@ void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_
|
||||||
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
|
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
|
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
|
||||||
if (data_ == nullptr) {
|
if (data_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
return std::vector<char>();
|
return std::vector<char>();
|
||||||
|
|
|
@ -104,7 +104,7 @@ lite::InnerContext *ContextUtils::Convert(Context *context) {
|
||||||
} else if (device->GetDeviceType() == kKirinNPU) {
|
} else if (device->GetDeviceType() == kKirinNPU) {
|
||||||
auto npu_context = device->Cast<KirinNPUDeviceInfo>();
|
auto npu_context = device->Cast<KirinNPUDeviceInfo>();
|
||||||
ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
|
ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
|
||||||
} else if (device->GetDeviceType() == kAscend310) {
|
} else if (device->GetDeviceType() == kAscend) {
|
||||||
ret = AddAscend310Device(inner_context.get(), device.get());
|
ret = AddAscend310Device(inner_context.get(), device.get());
|
||||||
}
|
}
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
|
|
|
@ -71,11 +71,11 @@ std::shared_ptr<mindspore::Context> Common::ContextAutoSet() {
|
||||||
auto context = std::make_shared<mindspore::Context>();
|
auto context = std::make_shared<mindspore::Context>();
|
||||||
|
|
||||||
if (device_target_str == "Ascend310") {
|
if (device_target_str == "Ascend310") {
|
||||||
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
|
||||||
ascend310_info->SetDeviceID(device_id);
|
ascend310_info->SetDeviceID(device_id);
|
||||||
context->MutableDeviceInfo().emplace_back(ascend310_info);
|
context->MutableDeviceInfo().emplace_back(ascend310_info);
|
||||||
} else if (device_target_str == "Ascend910") {
|
} else if (device_target_str == "Ascend910") {
|
||||||
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
auto ascend310_info = std::make_shared<mindspore::AscendDeviceInfo>();
|
||||||
ascend310_info->SetDeviceID(device_id);
|
ascend310_info->SetDeviceID(device_id);
|
||||||
context->MutableDeviceInfo().emplace_back(ascend310_info);
|
context->MutableDeviceInfo().emplace_back(ascend310_info);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -101,7 +101,7 @@ TEST_F(TestDE, TestDvpp) {
|
||||||
auto context = ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
ASSERT_TRUE(context != nullptr);
|
ASSERT_TRUE(context != nullptr);
|
||||||
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ascend310_info != nullptr);
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
auto device_id = ascend310_info->GetDeviceID();
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ TEST_F(TestDE, TestDvppSinkMode) {
|
||||||
auto context = ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
ASSERT_TRUE(context != nullptr);
|
ASSERT_TRUE(context != nullptr);
|
||||||
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ascend310_info != nullptr);
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
auto device_id = ascend310_info->GetDeviceID();
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
|
||||||
auto context = ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
ASSERT_TRUE(context != nullptr);
|
ASSERT_TRUE(context != nullptr);
|
||||||
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ascend310_info != nullptr);
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
auto device_id = ascend310_info->GetDeviceID();
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ TEST_F(TestDynamicBatchSize, InferMindIR) {
|
||||||
auto context = ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
ASSERT_TRUE(context != nullptr);
|
ASSERT_TRUE(context != nullptr);
|
||||||
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ascend310_info != nullptr);
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
|
|
||||||
std::map<int, std::vector<int>> input_shape;
|
std::map<int, std::vector<int>> input_shape;
|
||||||
|
|
|
@ -59,7 +59,7 @@ TEST_F(TestZeroCopy, TestMindIR) {
|
||||||
auto context = ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
ASSERT_TRUE(context != nullptr);
|
ASSERT_TRUE(context != nullptr);
|
||||||
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ascend310_info != nullptr);
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
ascend310_info->SetInsertOpConfigPath(aipp_path);
|
ascend310_info->SetInsertOpConfigPath(aipp_path);
|
||||||
auto device_id = ascend310_info->GetDeviceID();
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
@ -107,7 +107,7 @@ TEST_F(TestZeroCopy, TestDeviceTensor) {
|
||||||
auto context = ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
ASSERT_TRUE(context != nullptr);
|
ASSERT_TRUE(context != nullptr);
|
||||||
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ascend310_info != nullptr);
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
ascend310_info->SetInsertOpConfigPath(aipp_path);
|
ascend310_info->SetInsertOpConfigPath(aipp_path);
|
||||||
auto device_id = ascend310_info->GetDeviceID();
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
|
@ -27,32 +27,27 @@ TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) {
|
||||||
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
|
|
||||||
|
|
||||||
ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr);
|
ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr);
|
||||||
ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr);
|
ASSERT_TRUE(gpu->Cast<GPUDeviceInfo>() != nullptr);
|
||||||
ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr);
|
ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr);
|
||||||
ASSERT_TRUE(ascend310->Cast<Ascend310DeviceInfo>() != nullptr);
|
ASSERT_TRUE(ascend->Cast<AscendDeviceInfo>() != nullptr);
|
||||||
ASSERT_TRUE(ascend910->Cast<Ascend910DeviceInfo>() != nullptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) {
|
TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) {
|
||||||
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> gpu = std::make_shared<GPUDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
|
std::shared_ptr<DeviceInfoContext> ascend = std::make_shared<AscendDeviceInfo>();
|
||||||
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
|
|
||||||
|
|
||||||
ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr);
|
ASSERT_TRUE(cpu->Cast<GPUDeviceInfo>() == nullptr);
|
||||||
ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr);
|
ASSERT_TRUE(kirin_npu->Cast<GPUDeviceInfo>() == nullptr);
|
||||||
ASSERT_TRUE(ascend310->Cast<GPUDeviceInfo>() == nullptr);
|
ASSERT_TRUE(ascend->Cast<GPUDeviceInfo>() == nullptr);
|
||||||
ASSERT_TRUE(ascend910->Cast<GPUDeviceInfo>() == nullptr);
|
|
||||||
|
|
||||||
ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr);
|
ASSERT_TRUE(gpu->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr);
|
ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
ASSERT_TRUE(ascend310->Cast<CPUDeviceInfo>() == nullptr);
|
ASSERT_TRUE(ascend->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
ASSERT_TRUE(ascend910->Cast<CPUDeviceInfo>() == nullptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
|
TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
|
||||||
|
@ -86,7 +81,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
|
||||||
std::string option_9_ans = "1,2,3,4,5";
|
std::string option_9_ans = "1,2,3,4,5";
|
||||||
|
|
||||||
auto context = std::make_shared<Context>();
|
auto context = std::make_shared<Context>();
|
||||||
std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>();
|
std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>();
|
||||||
ascend310->SetInputShape(option_1);
|
ascend310->SetInputShape(option_1);
|
||||||
ascend310->SetInsertOpConfigPath(option_2);
|
ascend310->SetInsertOpConfigPath(option_2);
|
||||||
ascend310->SetOpSelectImplMode(option_3);
|
ascend310->SetOpSelectImplMode(option_3);
|
||||||
|
@ -99,7 +94,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
|
||||||
|
|
||||||
context->MutableDeviceInfo().push_back(ascend310);
|
context->MutableDeviceInfo().push_back(ascend310);
|
||||||
ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
|
ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
|
||||||
auto ctx = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
auto ctx = context->MutableDeviceInfo()[0]->Cast<AscendDeviceInfo>();
|
||||||
ASSERT_TRUE(ctx != nullptr);
|
ASSERT_TRUE(ctx != nullptr);
|
||||||
ASSERT_EQ(ascend310->GetInputShape(), option_1);
|
ASSERT_EQ(ascend310->GetInputShape(), option_1);
|
||||||
ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2);
|
ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2);
|
||||||
|
@ -113,7 +108,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
|
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
|
||||||
auto ctx = std::make_shared<Ascend310DeviceInfo>();
|
auto ctx = std::make_shared<AscendDeviceInfo>();
|
||||||
ASSERT_EQ(ctx->GetOpSelectImplMode(), "");
|
ASSERT_EQ(ctx->GetOpSelectImplMode(), "");
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue