/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
/// \brief Kinds of device a DeviceInfoContext can describe.
///
/// Every concrete DeviceInfoContext subclass returns one of these from GetDeviceType().
/// kInvalidDeviceType is a sentinel outside the range of real devices.
enum DeviceType {
  kCPU = 0,
  kMaliGPU,
  kNvidiaGPU,
  kKirinNPU,
  kAscend910,
  kAscend310,
  // add new type here
  kInvalidDeviceType = 100,
};

class Allocator;
class DeviceInfoContext;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
class MS_API Context {
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
~Context() = default;
void SetThreadNum(int32_t thread_num);
int32_t GetThreadNum() const;
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
std::shared_ptr<Allocator> GetAllocator() const;
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
2021-02-19 17:09:40 +08:00
struct Data;
2021-03-18 21:01:55 +08:00
std::shared_ptr<Data> data_;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
struct Data;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
virtual ~DeviceInfoContext() = default;
virtual enum DeviceType GetDeviceType() const = 0;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
template <class T>
std::shared_ptr<T> Cast() {
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
if (GetDeviceType() != T().GetDeviceType()) {
return nullptr;
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
return std::static_pointer_cast<T>(shared_from_this());
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
std::shared_ptr<Data> data_;
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
class MS_API CPUDeviceInfo : public DeviceInfoContext {
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
2021-03-23 09:40:31 +08:00
/// \brief Set the thread affinity to CPU cores.
2021-03-18 21:01:55 +08:00
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
int GetThreadAffinity() const;
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
class MS_API MaliGPUDeviceInfo : public DeviceInfoContext {
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
enum DeviceType GetDeviceType() const override { return DeviceType::kMaliGPU; };
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
void SetEnableFP16(bool is_fp16);
bool GetEnableFP16() const;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
void SetFrequency(int frequency);
int GetFrequency() const;
2021-03-08 10:28:44 +08:00
2021-03-18 21:01:55 +08:00
class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {
enum DeviceType GetDeviceType() const override { return DeviceType::kNvidiaGPU; };
2021-03-08 10:28:44 +08:00
2021-03-18 21:01:55 +08:00
void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
2021-01-26 17:06:34 +08:00
2021-03-18 21:01:55 +08:00
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
bool GetGpuTrtInferMode() const;
2021-04-25 09:53:47 +08:00
inline void SetPrecisionMode(const std::string &precison_mode);
inline std::string GetPrecisionMode() const;
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
2021-03-18 21:01:55 +08:00
2021-02-19 17:09:40 +08:00
2021-04-25 09:53:47 +08:00
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
std::string NvidiaGPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
2021-03-18 21:01:55 +08:00
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
2021-03-04 15:08:26 +08:00
2021-03-18 21:01:55 +08:00
void SetDeviceID(uint32_t device_id);
uint32_t GetDeviceID() const;
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
inline void SetDumpConfigPath(const std::string &cfg_path);
inline std::string GetDumpConfigPath() const;
2021-02-19 17:09:40 +08:00
2021-03-07 17:06:18 +08:00
// aipp config file
2021-03-18 21:01:55 +08:00
inline void SetInsertOpConfigPath(const std::string &cfg_path);
inline std::string GetInsertOpConfigPath() const;
2021-02-19 17:09:40 +08:00
2021-03-07 17:06:18 +08:00
// nchw or nhwc
2021-03-18 21:01:55 +08:00
inline void SetInputFormat(const std::string &format);
inline std::string GetInputFormat() const;
2021-01-26 17:06:34 +08:00
2021-03-07 17:06:18 +08:00
// Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
2021-03-18 21:01:55 +08:00
inline void SetInputShape(const std::string &shape);
inline std::string GetInputShape() const;
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
std::map<int, std::vector<int>> GetInputShapeMap() const;
2021-03-04 15:08:26 +08:00
2021-03-18 21:01:55 +08:00
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
inline std::string GetDynamicBatchSize() const;
2021-02-19 17:09:40 +08:00
2021-03-07 17:06:18 +08:00
// FP32, UINT8 or FP16, default as FP32
2021-03-18 21:01:55 +08:00
void SetOutputType(enum DataType output_type);
enum DataType GetOutputType() const;
2021-02-19 17:09:40 +08:00
2021-03-07 17:06:18 +08:00
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
2021-03-18 21:01:55 +08:00
inline void SetPrecisionMode(const std::string &precision_mode);
inline std::string GetPrecisionMode() const;
2021-03-04 11:30:53 +08:00
2021-03-07 17:06:18 +08:00
// Optional "high_performance" and "high_precision", "high_performance" is set as default
2021-03-18 21:01:55 +08:00
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
inline std::string GetOpSelectImplMode() const;
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
inline std::string GetFusionSwitchConfigPath() const;
2021-02-19 17:09:40 +08:00
2021-03-07 17:06:18 +08:00
// Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
inline std::string GetBufferOptimizeMode() const;
2021-03-18 21:01:55 +08:00
void SetDumpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetDumpConfigPathChar() const;
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetInsertOpConfigPathChar() const;
2021-02-19 17:09:40 +08:00
2021-03-18 21:01:55 +08:00
void SetInputFormat(const std::vector<char> &format);
std::vector<char> GetInputFormatChar() const;
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
void SetInputShape(const std::vector<char> &shape);
std::vector<char> GetInputShapeChar() const;
std::vector<char> GetDynamicBatchSizeChar() const;
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
std::vector<char> GetOpSelectImplModeChar() const;
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
std::vector<char> GetFusionSwitchConfigPathChar() const;
2021-03-07 17:06:18 +08:00
void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
std::vector<char> GetBufferOptimizeModeChar() const;
2021-03-18 21:01:55 +08:00
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
2021-03-04 11:30:53 +08:00
2021-03-18 21:01:55 +08:00
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
2021-03-04 15:08:26 +08:00
2021-03-18 21:01:55 +08:00
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
2021-03-08 10:28:44 +08:00
2021-03-18 21:01:55 +08:00
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
2021-03-08 10:28:44 +08:00
2021-03-18 21:01:55 +08:00
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
2021-03-04 15:08:26 +08:00
2021-03-18 21:01:55 +08:00
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
return CharToString(GetFusionSwitchConfigPathChar());
2021-03-04 15:08:26 +08:00
2021-03-07 17:06:18 +08:00
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
2020-12-01 18:35:15 +08:00
} // namespace mindspore