refactor cpp context, add string tensor, add get tensor by name
This commit is contained in:
parent
7d5701c3e9
commit
693b4cfdc8
|
@ -148,7 +148,7 @@ if(PLATFORM_ARM64)
|
||||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend* ops*" EXCLUDE)
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/operator_library DESTINATION ${CODEGEN_ROOT_DIR}
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/operator_library DESTINATION ${CODEGEN_ROOT_DIR}
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
if(ENABLE_TOOLS)
|
if(ENABLE_TOOLS)
|
||||||
|
@ -173,7 +173,7 @@ elseif(PLATFORM_ARM32)
|
||||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/operator_library DESTINATION ${CODEGEN_ROOT_DIR}
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/operator_library DESTINATION ${CODEGEN_ROOT_DIR}
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
if(ENABLE_TOOLS)
|
if(ENABLE_TOOLS)
|
||||||
|
@ -213,7 +213,7 @@ elseif(WIN32)
|
||||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||||
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.dll.a DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.dll.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
|
@ -231,7 +231,7 @@ else()
|
||||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
|
|
|
@ -103,8 +103,9 @@ class MS_API GraphCell final : public Cell<GraphCell> {
|
||||||
std::vector<MSTensor> GetOutputs();
|
std::vector<MSTensor> GetOutputs();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class Model;
|
||||||
friend class ModelImpl;
|
friend class ModelImpl;
|
||||||
Status Load();
|
Status Load(uint32_t device_id);
|
||||||
|
|
||||||
std::shared_ptr<Graph> graph_;
|
std::shared_ptr<Graph> graph_;
|
||||||
std::shared_ptr<GraphImpl> executor_;
|
std::shared_ptr<GraphImpl> executor_;
|
||||||
|
|
|
@ -24,162 +24,201 @@
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
constexpr auto kDeviceTypeAscend310 = "Ascend310";
|
enum DeviceType {
|
||||||
constexpr auto kDeviceTypeAscend910 = "Ascend910";
|
kCPU = 0,
|
||||||
constexpr auto kDeviceTypeGPU = "GPU";
|
kMaliGPU,
|
||||||
|
kNvidiaGPU,
|
||||||
|
kKirinNPU,
|
||||||
|
kAscend910,
|
||||||
|
kAscend310,
|
||||||
|
// add new type here
|
||||||
|
kInvalidDeviceType = 100,
|
||||||
|
};
|
||||||
|
|
||||||
struct MS_API Context {
|
class Allocator;
|
||||||
|
class DeviceInfoContext;
|
||||||
|
|
||||||
|
class MS_API Context {
|
||||||
public:
|
public:
|
||||||
Context();
|
Context();
|
||||||
virtual ~Context() = default;
|
~Context() = default;
|
||||||
|
|
||||||
|
void SetThreadNum(int32_t thread_num);
|
||||||
|
int32_t GetThreadNum() const;
|
||||||
|
|
||||||
|
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
||||||
|
std::shared_ptr<Allocator> GetAllocator() const;
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
||||||
|
|
||||||
|
private:
|
||||||
struct Data;
|
struct Data;
|
||||||
std::shared_ptr<Data> data;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MS_API GlobalContext : public Context {
|
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
|
||||||
public:
|
public:
|
||||||
static std::shared_ptr<Context> GetGlobalContext();
|
struct Data;
|
||||||
|
|
||||||
static inline void SetGlobalDeviceTarget(const std::string &device_target);
|
DeviceInfoContext();
|
||||||
static inline std::string GetGlobalDeviceTarget();
|
virtual ~DeviceInfoContext() = default;
|
||||||
|
virtual enum DeviceType GetDeviceType() const = 0;
|
||||||
|
|
||||||
static void SetGlobalDeviceID(const uint32_t &device_id);
|
template <class T>
|
||||||
static uint32_t GetGlobalDeviceID();
|
std::shared_ptr<T> Cast() {
|
||||||
|
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
|
||||||
|
if (GetDeviceType() != T().GetDeviceType()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
static inline void SetGlobalDumpConfigPath(const std::string &cfg_path);
|
return std::static_pointer_cast<T>(shared_from_this());
|
||||||
static inline std::string GetGlobalDumpConfigPath();
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::shared_ptr<Data> data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MS_API CPUDeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
|
||||||
|
|
||||||
|
/// \brief Set the thread affinity of CPU cores.
|
||||||
|
///
|
||||||
|
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
|
||||||
|
void SetThreadAffinity(int mode);
|
||||||
|
int GetThreadAffinity() const;
|
||||||
|
void SetEnableFP16(bool is_fp16);
|
||||||
|
bool GetEnableFP16() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MS_API MaliGPUDeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kMaliGPU; };
|
||||||
|
|
||||||
|
void SetEnableFP16(bool is_fp16);
|
||||||
|
bool GetEnableFP16() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
|
||||||
|
|
||||||
|
void SetFrequency(int frequency);
|
||||||
|
int GetFrequency() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kNvidiaGPU; };
|
||||||
|
|
||||||
|
void SetDeviceID(uint32_t device_id);
|
||||||
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
|
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
|
||||||
|
bool GetGpuTrtInferMode() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
||||||
|
|
||||||
|
void SetDeviceID(uint32_t device_id);
|
||||||
|
uint32_t GetDeviceID() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
||||||
|
public:
|
||||||
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
|
||||||
|
|
||||||
|
void SetDeviceID(uint32_t device_id);
|
||||||
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
|
inline void SetDumpConfigPath(const std::string &cfg_path);
|
||||||
|
inline std::string GetDumpConfigPath() const;
|
||||||
|
|
||||||
|
inline void SetInsertOpConfigPath(const std::string &cfg_path);
|
||||||
|
inline std::string GetInsertOpConfigPath() const;
|
||||||
|
|
||||||
|
inline void SetInputFormat(const std::string &format);
|
||||||
|
inline std::string GetInputFormat() const;
|
||||||
|
|
||||||
|
inline void SetInputShape(const std::string &shape);
|
||||||
|
inline std::string GetInputShape() const;
|
||||||
|
|
||||||
|
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
|
||||||
|
std::map<int, std::vector<int>> GetInputShapeMap() const;
|
||||||
|
|
||||||
|
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
|
||||||
|
inline std::string GetDynamicBatchSize() const;
|
||||||
|
|
||||||
|
void SetOutputType(enum DataType output_type);
|
||||||
|
enum DataType GetOutputType() const;
|
||||||
|
|
||||||
|
inline void SetPrecisionMode(const std::string &precision_mode);
|
||||||
|
inline std::string GetPrecisionMode() const;
|
||||||
|
|
||||||
|
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
|
||||||
|
inline std::string GetOpSelectImplMode() const;
|
||||||
|
|
||||||
|
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
|
||||||
|
inline std::string GetFusionSwitchConfigPath() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// api without std::string
|
void SetDumpConfigPath(const std::vector<char> &cfg_path);
|
||||||
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
|
std::vector<char> GetDumpConfigPathChar() const;
|
||||||
static std::vector<char> GetGlobalDeviceTargetChar();
|
|
||||||
|
|
||||||
static void SetGlobalDumpConfigPath(const std::vector<char> &cfg_path);
|
void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
|
||||||
static std::vector<char> GetGlobalDumpConfigPathChar();
|
std::vector<char> GetInsertOpConfigPathChar() const;
|
||||||
|
|
||||||
|
void SetInputFormat(const std::vector<char> &format);
|
||||||
|
std::vector<char> GetInputFormatChar() const;
|
||||||
|
|
||||||
|
void SetInputShape(const std::vector<char> &shape);
|
||||||
|
std::vector<char> GetInputShapeChar() const;
|
||||||
|
|
||||||
|
std::vector<char> GetDynamicBatchSizeChar() const;
|
||||||
|
|
||||||
|
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
||||||
|
std::vector<char> GetPrecisionModeChar() const;
|
||||||
|
|
||||||
|
void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
|
||||||
|
std::vector<char> GetOpSelectImplModeChar() const;
|
||||||
|
|
||||||
|
void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
|
||||||
|
std::vector<char> GetFusionSwitchConfigPathChar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MS_API ModelContext : public Context {
|
void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
|
||||||
public:
|
std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
|
||||||
static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
|
|
||||||
static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
|
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
|
||||||
static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);
|
SetInsertOpConfigPath(StringToChar(cfg_path));
|
||||||
|
|
||||||
static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
|
|
||||||
static inline std::string GetInputShape(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetInputShapeMap(const std::shared_ptr<Context> &context, const std::map<int, std::vector<int>> &shape);
|
|
||||||
static std::map<int, std::vector<int>> GetInputShapeMap(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetDynamicBatchSize(const std::shared_ptr<Context> &context,
|
|
||||||
const std::vector<size_t> &dynamic_batch_size);
|
|
||||||
static inline std::string GetDynamicBatchSize(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
|
|
||||||
static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
|
|
||||||
static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
|
||||||
const std::string &op_select_impl_mode);
|
|
||||||
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static inline void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
|
|
||||||
static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static inline void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode);
|
|
||||||
static inline std::string GetGpuTrtInferMode(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
private:
|
|
||||||
// api without std::string
|
|
||||||
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
|
|
||||||
static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
|
|
||||||
static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
|
|
||||||
static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
|
|
||||||
static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
|
||||||
const std::vector<char> &op_select_impl_mode);
|
|
||||||
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
|
|
||||||
static std::vector<char> GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::vector<char> &gpu_trt_infer_mode);
|
|
||||||
static std::vector<char> GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context);
|
|
||||||
static std::vector<char> GetDynamicBatchSizeChar(const std::shared_ptr<Context> &context);
|
|
||||||
};
|
|
||||||
|
|
||||||
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
|
|
||||||
SetGlobalDeviceTarget(StringToChar(device_target));
|
|
||||||
}
|
}
|
||||||
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
|
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
|
||||||
|
|
||||||
void GlobalContext::SetGlobalDumpConfigPath(const std::string &cfg_path) {
|
void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
|
||||||
SetGlobalDumpConfigPath(StringToChar(cfg_path));
|
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
|
||||||
}
|
|
||||||
std::string GlobalContext::GetGlobalDumpConfigPath() { return CharToString(GetGlobalDumpConfigPathChar()); }
|
|
||||||
|
|
||||||
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
|
||||||
SetInsertOpConfigPath(context, StringToChar(cfg_path));
|
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
|
||||||
}
|
|
||||||
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetInsertOpConfigPathChar(context));
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
|
std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
|
||||||
SetInputFormat(context, StringToChar(format));
|
|
||||||
}
|
|
||||||
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetInputFormatChar(context));
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
|
void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
SetInputShape(context, StringToChar(shape));
|
SetPrecisionMode(StringToChar(precision_mode));
|
||||||
}
|
|
||||||
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetInputShapeChar(context));
|
|
||||||
}
|
}
|
||||||
|
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
||||||
|
|
||||||
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
|
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
|
||||||
SetPrecisionMode(context, StringToChar(precision_mode));
|
SetOpSelectImplMode(StringToChar(op_select_impl_mode));
|
||||||
}
|
|
||||||
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetPrecisionModeChar(context));
|
|
||||||
}
|
}
|
||||||
|
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
|
||||||
|
|
||||||
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
|
||||||
const std::string &op_select_impl_mode) {
|
SetFusionSwitchConfigPath(StringToChar(cfg_path));
|
||||||
SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
|
|
||||||
}
|
}
|
||||||
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
|
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
|
||||||
return CharToString(GetOpSelectImplModeChar(context));
|
return CharToString(GetFusionSwitchConfigPathChar());
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
|
||||||
SetFusionSwitchConfigPath(context, StringToChar(cfg_path));
|
|
||||||
}
|
|
||||||
std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetFusionSwitchConfigPathChar(context));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ModelContext::GetDynamicBatchSize(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetDynamicBatchSizeChar(context));
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode) {
|
|
||||||
SetGpuTrtInferMode(context, StringToChar(gpu_trt_infer_mode));
|
|
||||||
}
|
|
||||||
std::string ModelContext::GetGpuTrtInferMode(const std::shared_ptr<Context> &context) {
|
|
||||||
return CharToString(GetGpuTrtInferModeChar(context));
|
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||||
|
|
|
@ -27,6 +27,7 @@ namespace mindspore {
|
||||||
class MS_API Graph {
|
class MS_API Graph {
|
||||||
public:
|
public:
|
||||||
class GraphData;
|
class GraphData;
|
||||||
|
Graph();
|
||||||
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
|
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
|
||||||
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
|
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
|
||||||
explicit Graph(std::nullptr_t);
|
explicit Graph(std::nullptr_t);
|
||||||
|
@ -34,6 +35,7 @@ class MS_API Graph {
|
||||||
|
|
||||||
enum ModelType ModelType() const;
|
enum ModelType ModelType() const;
|
||||||
bool operator==(std::nullptr_t) const;
|
bool operator==(std::nullptr_t) const;
|
||||||
|
bool operator!=(std::nullptr_t) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class GraphCell;
|
friend class GraphCell;
|
||||||
|
|
|
@ -1,71 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
|
|
||||||
#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
|
||||||
#include <any>
|
|
||||||
#include "include/api/types.h"
|
|
||||||
#include "include/lite_types.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
class Allocator;
|
|
||||||
} // namespace lite
|
|
||||||
|
|
||||||
struct MS_API Context {
|
|
||||||
public:
|
|
||||||
static void Clear(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetAsDefault(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetVendorName(const std::shared_ptr<Context> &context, const std::string &name);
|
|
||||||
static std::string GetVendorName(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetThreadNum(const std::shared_ptr<Context> &context, int num);
|
|
||||||
static int GetThreadNum(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetAllocator(const std::shared_ptr<Context> &context, std::shared_ptr<lite::Allocator> alloc);
|
|
||||||
static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void ConfigCPU(const std::shared_ptr<Context> &context, bool config);
|
|
||||||
static bool IfCPUEnabled(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void ConfigCPUFp16(const std::shared_ptr<Context> &context, bool config);
|
|
||||||
static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetCPUBindMode(const std::shared_ptr<Context> &context, lite::CpuBindMode mode);
|
|
||||||
static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void ConfigGPU(const std::shared_ptr<Context> &context, bool config);
|
|
||||||
static bool IfGPUEnabled(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void ConfigGPUFp16(const std::shared_ptr<Context> &context, bool config);
|
|
||||||
static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void ConfigNPU(const std::shared_ptr<Context> &context, bool config);
|
|
||||||
static bool IfNPUEnabled(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
static void SetNPUFrequency(const std::shared_ptr<Context> &context, int freq);
|
|
||||||
static int GetNPUFrequency(const std::shared_ptr<Context> &context);
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::map<std::string, std::any> context_;
|
|
||||||
};
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
|
|
|
@ -24,39 +24,52 @@
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/graph.h"
|
#include "include/api/graph.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
#include "include/api/cell.h"
|
#include "include/api/cell.h"
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelImpl;
|
class ModelImpl;
|
||||||
struct Context;
|
|
||||||
|
|
||||||
class MS_API Model {
|
class MS_API Model {
|
||||||
public:
|
public:
|
||||||
explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
|
Model();
|
||||||
explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
|
|
||||||
~Model();
|
~Model();
|
||||||
Model(const Model &) = delete;
|
Model(const Model &) = delete;
|
||||||
void operator=(const Model &) = delete;
|
void operator=(const Model &) = delete;
|
||||||
|
|
||||||
Status Build();
|
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
|
||||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||||
|
|
||||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
||||||
|
|
||||||
std::vector<MSTensor> GetInputs();
|
std::vector<MSTensor> GetInputs();
|
||||||
std::vector<MSTensor> GetOutputs();
|
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
|
||||||
|
|
||||||
static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
std::vector<MSTensor> GetOutputs();
|
||||||
|
inline std::vector<std::string> GetOutputTensorNames();
|
||||||
|
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
|
||||||
|
|
||||||
|
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// api without std::string
|
// api without std::string
|
||||||
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
|
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
|
||||||
|
std::vector<std::vector<char>> GetOutputTensorNamesChar();
|
||||||
|
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
|
||||||
|
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
||||||
|
|
||||||
std::shared_ptr<ModelImpl> impl_;
|
std::shared_ptr<ModelImpl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) {
|
||||||
return CheckModelSupport(StringToChar(device_type), model_type);
|
return GetInputByTensorName(StringToChar(tensor_name));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> Model::GetOutputTensorNames() { return VectorCharToString(GetOutputTensorNamesChar()); }
|
||||||
|
|
||||||
|
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
|
||||||
|
return GetOutputByTensorName(StringToChar(tensor_name));
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||||
|
|
|
@ -29,19 +29,19 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class MS_API Serialization {
|
class MS_API Serialization {
|
||||||
public:
|
public:
|
||||||
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
|
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
|
||||||
inline static Graph LoadModel(const std::string &file, ModelType model_type);
|
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
|
||||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
|
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
|
||||||
};
|
};
|
||||||
|
|
||||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
|
||||||
return LoadModel(StringToChar(file), model_type);
|
return Load(StringToChar(file), model_type, graph);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||||
|
|
|
@ -43,15 +43,19 @@ class MS_API MSTensor {
|
||||||
public:
|
public:
|
||||||
class Impl;
|
class Impl;
|
||||||
|
|
||||||
static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
|
static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
|
||||||
|
static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
|
||||||
|
static void DestroyTensorPtr(MSTensor *tensor) noexcept;
|
||||||
|
|
||||||
MSTensor();
|
MSTensor();
|
||||||
explicit MSTensor(const std::shared_ptr<Impl> &impl);
|
explicit MSTensor(const std::shared_ptr<Impl> &impl);
|
||||||
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
|
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||||
size_t data_len);
|
size_t data_len);
|
||||||
|
explicit MSTensor(std::nullptr_t);
|
||||||
~MSTensor();
|
~MSTensor();
|
||||||
|
|
||||||
inline std::string Name() const;
|
inline std::string Name() const;
|
||||||
|
@ -65,21 +69,24 @@ class MS_API MSTensor {
|
||||||
|
|
||||||
bool IsDevice() const;
|
bool IsDevice() const;
|
||||||
|
|
||||||
MSTensor Clone() const;
|
MSTensor *Clone() const;
|
||||||
bool operator==(std::nullptr_t) const;
|
bool operator==(std::nullptr_t) const;
|
||||||
|
bool operator!=(std::nullptr_t) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// api without std::string
|
// api without std::string
|
||||||
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
static MSTensor *CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
static MSTensor *CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
|
static MSTensor *CharStringsToTensor(const std::vector<char> &name, const std::vector<std::vector<char>> &str);
|
||||||
|
static std::vector<std::vector<char>> TensorToStringChars(const MSTensor &tensor);
|
||||||
|
|
||||||
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||||
size_t data_len);
|
size_t data_len);
|
||||||
std::vector<char> CharName() const;
|
std::vector<char> CharName() const;
|
||||||
|
|
||||||
friend class ModelImpl;
|
friend class ModelImpl;
|
||||||
explicit MSTensor(std::nullptr_t);
|
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -103,16 +110,24 @@ class MS_API Buffer {
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
MSTensor *MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept {
|
const void *data, size_t data_len) noexcept {
|
||||||
return CreateTensor(StringToChar(name), type, shape, data, data_len);
|
return CreateTensor(StringToChar(name), type, shape, data, data_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
MSTensor *MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept {
|
const void *data, size_t data_len) noexcept {
|
||||||
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
|
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSTensor *MSTensor::StringsToTensor(const std::string &name, const std::vector<std::string> &str) {
|
||||||
|
return CharStringsToTensor(StringToChar(name), VectorStringToChar(str));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> MSTensor::TensorToStrings(const MSTensor &tensor) {
|
||||||
|
return VectorCharToString(TensorToStringChars(tensor));
|
||||||
|
}
|
||||||
|
|
||||||
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||||
size_t data_len)
|
size_t data_len)
|
||||||
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
|
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
|
||||||
|
|
|
@ -1,134 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_INFERENCE_LOG_H_
|
|
||||||
#define MINDSPORE_INFERENCE_LOG_H_
|
|
||||||
|
|
||||||
#include <stdarg.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <string>
|
|
||||||
#include <sstream>
|
|
||||||
#include <memory>
|
|
||||||
#include <iostream>
|
|
||||||
#include <chrono>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#ifndef ENABLE_ACL
|
|
||||||
#include "mindspore/core/utils/log_adapter.h"
|
|
||||||
#else // ENABLE_ACL
|
|
||||||
#include "acl/acl.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace mindspore::inference {
|
|
||||||
|
|
||||||
class LogStream {
|
|
||||||
public:
|
|
||||||
LogStream() { sstream_ = std::make_shared<std::stringstream>(); }
|
|
||||||
~LogStream() = default;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
LogStream &operator<<(const T &val) noexcept {
|
|
||||||
(*sstream_) << val;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
LogStream &operator<<(const std::vector<T> &val) noexcept {
|
|
||||||
(*sstream_) << "[";
|
|
||||||
for (size_t i = 0; i < val.size(); i++) {
|
|
||||||
(*this) << val[i];
|
|
||||||
if (i + 1 < val.size()) {
|
|
||||||
(*sstream_) << ", ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(*sstream_) << "]";
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept {
|
|
||||||
(*sstream_) << func;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend class LogWriter;
|
|
||||||
friend class Status;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<std::stringstream> sstream_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#ifndef ENABLE_ACL
|
|
||||||
#define MSI_LOG(level) MS_LOG(level)
|
|
||||||
|
|
||||||
#define MSI_LOG_DEBUG MSI_LOG(DEBUG)
|
|
||||||
#define MSI_LOG_INFO MSI_LOG(INFO)
|
|
||||||
#define MSI_LOG_WARNING MSI_LOG(WARNING)
|
|
||||||
#define MSI_LOG_ERROR MSI_LOG(ERROR)
|
|
||||||
|
|
||||||
#define MSI_ASSERT(item) MS_ASSERT(item)
|
|
||||||
|
|
||||||
#else // ENABLE_ACL
|
|
||||||
|
|
||||||
class LogWriter {
|
|
||||||
public:
|
|
||||||
LogWriter(const char *file, int line, const char *func, aclLogLevel log_level)
|
|
||||||
: file_(file), line_(line), func_(func), log_level_(log_level) {}
|
|
||||||
~LogWriter() = default;
|
|
||||||
|
|
||||||
void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << stream.sstream_->rdbuf();
|
|
||||||
OutputLog(msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void OutputLog(const std::ostringstream &msg) const { aclAppLog(log_level_, func_, file_, line_, msg.str().c_str()); }
|
|
||||||
|
|
||||||
const char *file_;
|
|
||||||
int line_;
|
|
||||||
const char *func_;
|
|
||||||
aclLogLevel log_level_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#define MSILOG_IF(level) inference::LogWriter(__FILE__, __LINE__, __FUNCTION__, ACL_##level) < inference::LogStream()
|
|
||||||
|
|
||||||
#define MSI_LOG(level) MSI_LOG_##level
|
|
||||||
|
|
||||||
#define MSI_LOG_DEBUG MSILOG_IF(DEBUG)
|
|
||||||
#define MSI_LOG_INFO MSILOG_IF(INFO)
|
|
||||||
#define MSI_LOG_WARNING MSILOG_IF(WARNING)
|
|
||||||
#define MSI_LOG_ERROR MSILOG_IF(ERROR)
|
|
||||||
|
|
||||||
#define MSI_ASSERT(item)
|
|
||||||
|
|
||||||
#endif // ENABLE_ACL
|
|
||||||
|
|
||||||
#define MSI_TIME_STAMP_START(name) auto time_start_##name = std::chrono::steady_clock::now();
|
|
||||||
#define MSI_TIME_STAMP_END(name) \
|
|
||||||
{ \
|
|
||||||
auto time_end_##name = std::chrono::steady_clock::now(); \
|
|
||||||
auto time_cost = std::chrono::duration<double, std::milli>(time_end_##name - time_start_##name).count(); \
|
|
||||||
MSI_LOG_INFO << #name " Time Cost # " << time_cost << " ms ---------------------"; \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define INFER_STATUS(code) inference::Status(code) < inference::LogStream()
|
|
||||||
#define ERROR_INFER_STATUS(status, type, msg) \
|
|
||||||
MSI_LOG_ERROR << msg; \
|
|
||||||
status = inference::Status(type, msg)
|
|
||||||
|
|
||||||
} // namespace mindspore::inference
|
|
||||||
|
|
||||||
#endif // MINDSPORE_INFERENCE_LOG_H_
|
|
|
@ -1,217 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_INCLUDE_INFER_TENSOR_H_
|
|
||||||
#define MINDSPORE_INCLUDE_INFER_TENSOR_H_
|
|
||||||
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include <numeric>
|
|
||||||
#include <map>
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#include "securec/include/securec.h"
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
#define MS_API __attribute__((visibility("default")))
|
|
||||||
namespace inference {
|
|
||||||
enum DataType {
|
|
||||||
kMSI_Unknown = 0,
|
|
||||||
kMSI_Bool = 1,
|
|
||||||
kMSI_Int8 = 2,
|
|
||||||
kMSI_Int16 = 3,
|
|
||||||
kMSI_Int32 = 4,
|
|
||||||
kMSI_Int64 = 5,
|
|
||||||
kMSI_Uint8 = 6,
|
|
||||||
kMSI_Uint16 = 7,
|
|
||||||
kMSI_Uint32 = 8,
|
|
||||||
kMSI_Uint64 = 9,
|
|
||||||
kMSI_Float16 = 10,
|
|
||||||
kMSI_Float32 = 11,
|
|
||||||
kMSI_Float64 = 12,
|
|
||||||
};
|
|
||||||
|
|
||||||
class InferTensorBase {
|
|
||||||
public:
|
|
||||||
InferTensorBase() = default;
|
|
||||||
virtual ~InferTensorBase() = default;
|
|
||||||
|
|
||||||
virtual DataType data_type() const = 0;
|
|
||||||
virtual void set_data_type(DataType type) = 0;
|
|
||||||
virtual std::vector<int64_t> shape() const = 0;
|
|
||||||
virtual void set_shape(const std::vector<int64_t> &shape) = 0;
|
|
||||||
virtual const void *data() const = 0;
|
|
||||||
virtual size_t data_size() const = 0;
|
|
||||||
virtual bool resize_data(size_t data_len) = 0;
|
|
||||||
virtual void *mutable_data() = 0;
|
|
||||||
|
|
||||||
bool set_data(const void *data, size_t data_len) {
|
|
||||||
resize_data(data_len);
|
|
||||||
if (mutable_data() == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "set data failed, data len " << data_len;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (data_size() != data_len) {
|
|
||||||
MSI_LOG_ERROR << "set data failed, tensor current data size " << data_size() << " not match data len "
|
|
||||||
<< data_len;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (data_len == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
auto ret = memcpy_s(mutable_data(), data_size(), data, data_len);
|
|
||||||
if (ret != 0) {
|
|
||||||
MSI_LOG_ERROR << "Set data memcpy_s failed";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t ElementNum() const {
|
|
||||||
std::vector<int64_t> shapex = shape();
|
|
||||||
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
|
|
||||||
}
|
|
||||||
|
|
||||||
int GetTypeSize(DataType type) const {
|
|
||||||
const std::map<DataType, size_t> type_size_map{
|
|
||||||
{kMSI_Bool, sizeof(bool)}, {kMSI_Float64, sizeof(double)}, {kMSI_Int8, sizeof(int8_t)},
|
|
||||||
{kMSI_Uint8, sizeof(uint8_t)}, {kMSI_Int16, sizeof(int16_t)}, {kMSI_Uint16, sizeof(uint16_t)},
|
|
||||||
{kMSI_Int32, sizeof(int32_t)}, {kMSI_Uint32, sizeof(uint32_t)}, {kMSI_Int64, sizeof(int64_t)},
|
|
||||||
{kMSI_Uint64, sizeof(uint64_t)}, {kMSI_Float16, sizeof(uint16_t)}, {kMSI_Float32, sizeof(float)},
|
|
||||||
};
|
|
||||||
auto it = type_size_map.find(type);
|
|
||||||
if (it != type_size_map.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class InferTensor : public InferTensorBase {
|
|
||||||
public:
|
|
||||||
DataType type_;
|
|
||||||
std::vector<int64_t> shape_;
|
|
||||||
std::vector<uint8_t> data_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
InferTensor() = default;
|
|
||||||
~InferTensor() = default;
|
|
||||||
InferTensor(DataType type, std::vector<int64_t> shape, const void *data, size_t data_len) {
|
|
||||||
set_data_type(type);
|
|
||||||
set_shape(shape);
|
|
||||||
set_data(data, data_len);
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_data_type(DataType type) override { type_ = type; }
|
|
||||||
DataType data_type() const override { return type_; }
|
|
||||||
|
|
||||||
void set_shape(const std::vector<int64_t> &shape) override { shape_ = shape; }
|
|
||||||
std::vector<int64_t> shape() const override { return shape_; }
|
|
||||||
|
|
||||||
const void *data() const override { return data_.data(); }
|
|
||||||
size_t data_size() const override { return data_.size(); }
|
|
||||||
|
|
||||||
bool resize_data(size_t data_len) override {
|
|
||||||
data_.resize(data_len);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
void *mutable_data() override { return data_.data(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
class InferImagesBase {
|
|
||||||
public:
|
|
||||||
InferImagesBase() = default;
|
|
||||||
virtual ~InferImagesBase() = default;
|
|
||||||
virtual size_t batch_size() const = 0;
|
|
||||||
virtual bool get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const = 0;
|
|
||||||
virtual size_t input_index() const = 0; // the index of images as input in model
|
|
||||||
};
|
|
||||||
|
|
||||||
class RequestBase {
|
|
||||||
public:
|
|
||||||
RequestBase() = default;
|
|
||||||
virtual ~RequestBase() = default;
|
|
||||||
virtual size_t size() const = 0;
|
|
||||||
virtual const InferTensorBase *operator[](size_t index) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ImagesRequestBase {
|
|
||||||
public:
|
|
||||||
ImagesRequestBase() = default;
|
|
||||||
virtual ~ImagesRequestBase() = default;
|
|
||||||
virtual size_t size() const = 0;
|
|
||||||
virtual const InferImagesBase *operator[](size_t index) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ReplyBase {
|
|
||||||
public:
|
|
||||||
ReplyBase() = default;
|
|
||||||
virtual ~ReplyBase() = default;
|
|
||||||
virtual size_t size() const = 0;
|
|
||||||
virtual InferTensorBase *operator[](size_t index) = 0;
|
|
||||||
virtual const InferTensorBase *operator[](size_t index) const = 0;
|
|
||||||
virtual InferTensorBase *add() = 0;
|
|
||||||
virtual void clear() = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class VectorInferTensorWrapReply : public ReplyBase {
|
|
||||||
public:
|
|
||||||
explicit VectorInferTensorWrapReply(std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
|
|
||||||
~VectorInferTensorWrapReply() = default;
|
|
||||||
|
|
||||||
size_t size() const { return tensor_list_.size(); }
|
|
||||||
InferTensorBase *operator[](size_t index) {
|
|
||||||
if (index >= tensor_list_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(tensor_list_[index]);
|
|
||||||
}
|
|
||||||
const InferTensorBase *operator[](size_t index) const {
|
|
||||||
if (index >= tensor_list_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(tensor_list_[index]);
|
|
||||||
}
|
|
||||||
InferTensorBase *add() {
|
|
||||||
tensor_list_.push_back(InferTensor());
|
|
||||||
return &(tensor_list_.back());
|
|
||||||
}
|
|
||||||
void clear() { tensor_list_.clear(); }
|
|
||||||
std::vector<InferTensor> &tensor_list_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class VectorInferTensorWrapRequest : public RequestBase {
|
|
||||||
public:
|
|
||||||
explicit VectorInferTensorWrapRequest(const std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
|
|
||||||
~VectorInferTensorWrapRequest() = default;
|
|
||||||
|
|
||||||
size_t size() const { return tensor_list_.size(); }
|
|
||||||
const InferTensorBase *operator[](size_t index) const {
|
|
||||||
if (index >= tensor_list_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(tensor_list_[index]);
|
|
||||||
}
|
|
||||||
const std::vector<InferTensor> &tensor_list_;
|
|
||||||
};
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_INCLUDE_INFER_TENSOR_H_
|
|
|
@ -1,86 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_INCLUDE_MS_SESSION_H
|
|
||||||
#define MINDSPORE_INCLUDE_MS_SESSION_H
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include "include/infer_tensor.h"
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace inference {
|
|
||||||
enum StatusCode { SUCCESS = 0, FAILED, INVALID_INPUTS };
|
|
||||||
|
|
||||||
class Status {
|
|
||||||
public:
|
|
||||||
Status() : status_code_(FAILED) {}
|
|
||||||
Status(enum StatusCode status_code, const std::string &status_msg = "")
|
|
||||||
: status_code_(status_code), status_msg_(status_msg) {}
|
|
||||||
~Status() = default;
|
|
||||||
|
|
||||||
bool IsSuccess() const { return status_code_ == SUCCESS; }
|
|
||||||
enum StatusCode StatusCode() const { return status_code_; }
|
|
||||||
std::string StatusMessage() const { return status_msg_; }
|
|
||||||
bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
|
|
||||||
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
|
|
||||||
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
|
|
||||||
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
|
|
||||||
operator bool() const = delete;
|
|
||||||
Status &operator<(const LogStream &stream) noexcept __attribute__((visibility("default"))) {
|
|
||||||
status_msg_ = stream.sstream_->str();
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
enum StatusCode status_code_;
|
|
||||||
std::string status_msg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class MS_API InferSession {
|
|
||||||
public:
|
|
||||||
InferSession() = default;
|
|
||||||
virtual ~InferSession() = default;
|
|
||||||
virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
|
|
||||||
virtual Status FinalizeEnv() = 0;
|
|
||||||
virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
|
|
||||||
virtual Status UnloadModel(uint32_t model_id) = 0;
|
|
||||||
// override this method to avoid request/reply data copy
|
|
||||||
virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
|
|
||||||
|
|
||||||
virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
|
|
||||||
std::vector<InferTensor> &outputs) {
|
|
||||||
VectorInferTensorWrapRequest request(inputs);
|
|
||||||
VectorInferTensorWrapReply reply(outputs);
|
|
||||||
return ExecuteModel(model_id, request, reply);
|
|
||||||
}
|
|
||||||
// default not support input data preprocess(decode, resize, crop, crop&paste, etc.)
|
|
||||||
virtual Status ExecuteModel(uint32_t /*model_id*/,
|
|
||||||
const ImagesRequestBase & /*images_inputs*/, // images for preprocess
|
|
||||||
const RequestBase & /*request*/, ReplyBase & /*reply*/) {
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
virtual Status GetModelInputsInfo(uint32_t graph_id, std::vector<inference::InferTensor> *tensor_list) const {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
static std::shared_ptr<InferSession> CreateSession(const std::string &device, uint32_t device_id);
|
|
||||||
};
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_INCLUDE_MS_SESSION_H
|
|
|
@ -21,12 +21,19 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
|
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
|
||||||
|
|
||||||
ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}
|
ParameterCell::ParameterCell(const ParameterCell &cell) {
|
||||||
|
auto tmp_ptr = cell.tensor_.Clone();
|
||||||
|
tensor_ = *tmp_ptr;
|
||||||
|
MSTensor::DestroyTensorPtr(tmp_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
|
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
|
||||||
if (&cell == this) {
|
if (&cell == this) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
tensor_ = cell.tensor_.Clone();
|
auto tmp_ptr = cell.tensor_.Clone();
|
||||||
|
tensor_ = *tmp_ptr;
|
||||||
|
MSTensor::DestroyTensorPtr(tmp_ptr);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,10 +47,16 @@ ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
ParameterCell::ParameterCell(const MSTensor &tensor) : tensor_(tensor.Clone()) {}
|
ParameterCell::ParameterCell(const MSTensor &tensor) {
|
||||||
|
auto tmp_ptr = tensor.Clone();
|
||||||
|
tensor_ = *tmp_ptr;
|
||||||
|
MSTensor::DestroyTensorPtr(tmp_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
|
ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
|
||||||
tensor_ = tensor.Clone();
|
auto tmp_ptr = tensor.Clone();
|
||||||
|
tensor_ = *tmp_ptr;
|
||||||
|
MSTensor::DestroyTensorPtr(tmp_ptr);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,54 +67,67 @@ ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphCell::GraphCell(const Graph &graph)
|
GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
|
||||||
: graph_(std::make_shared<Graph>(graph)),
|
|
||||||
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
|
|
||||||
MS_EXCEPTION_IF_NULL(graph_);
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
executor_->SetGraph(graph_);
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
|
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }
|
||||||
: graph_(graph),
|
|
||||||
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
|
|
||||||
MS_EXCEPTION_IF_NULL(graph_);
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
executor_->SetGraph(graph_);
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphCell::GraphCell(Graph &&graph)
|
GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
|
||||||
: graph_(std::make_shared<Graph>(graph)),
|
|
||||||
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
|
|
||||||
MS_EXCEPTION_IF_NULL(graph_);
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
executor_->SetGraph(graph_);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
if (executor_ == nullptr) {
|
||||||
|
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
||||||
|
if (executor_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
|
||||||
|
return kMEFailed;
|
||||||
|
}
|
||||||
|
executor_->SetGraph(graph_);
|
||||||
|
}
|
||||||
return executor_->Run(inputs, outputs);
|
return executor_->Run(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphCell::Load() {
|
Status GraphCell::Load(uint32_t device_id) {
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
if (executor_ == nullptr) {
|
||||||
return executor_->Load();
|
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
||||||
|
if (executor_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
|
||||||
|
return kMEFailed;
|
||||||
|
}
|
||||||
|
executor_->SetGraph(graph_);
|
||||||
|
}
|
||||||
|
return executor_->Load(device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MSTensor> GraphCell::GetInputs() {
|
std::vector<MSTensor> GraphCell::GetInputs() {
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
if (executor_ == nullptr) {
|
||||||
|
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
||||||
|
if (executor_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
executor_->SetGraph(graph_);
|
||||||
|
}
|
||||||
return executor_->GetInputs();
|
return executor_->GetInputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MSTensor> GraphCell::GetOutputs() {
|
std::vector<MSTensor> GraphCell::GetOutputs() {
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
if (executor_ == nullptr) {
|
||||||
|
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
|
||||||
|
if (executor_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
executor_->SetGraph(graph_);
|
||||||
|
}
|
||||||
return executor_->GetOutputs();
|
return executor_->GetOutputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
|
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
|
||||||
|
|
||||||
InputAndOutput::InputAndOutput(const MSTensor &tensor)
|
InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
|
||||||
: cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
|
auto tmp_ptr = tensor.Clone();
|
||||||
|
cell_ = std::make_shared<ParameterCell>(*tmp_ptr);
|
||||||
|
MSTensor::DestroyTensorPtr(tmp_ptr);
|
||||||
|
}
|
||||||
InputAndOutput::InputAndOutput(MSTensor &&tensor)
|
InputAndOutput::InputAndOutput(MSTensor &&tensor)
|
||||||
: cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
|
: cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
|
||||||
|
|
||||||
|
|
|
@ -17,41 +17,57 @@
|
||||||
#include <any>
|
#include <any>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
#include "cxx_api/factory.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
|
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
|
||||||
constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
|
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
||||||
constexpr auto kGlobalContextDumpCfgPath = "mindspore.ascend.globalcontext.dump_config_file_path";
|
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
||||||
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
|
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||||
constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc
|
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
||||||
constexpr auto kModelOptionInputShapeMap = "mindspore.option.input_shape_map";
|
constexpr auto kModelOptionNvidiaGpuDeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
|
constexpr auto kModelOptionNvidiaGpuTrtInferMode = "mindspore.option.nvidia_gpu.trt_infer_mode";
|
||||||
|
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
|
||||||
|
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
||||||
|
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
|
||||||
|
constexpr auto kModelOptionAscend310InsertOpCfgPath =
|
||||||
|
"mindspore.option.ascend310.insert_op_config_file_path"; // aipp config file
|
||||||
|
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format"; // nchw or nhwc
|
||||||
|
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
|
||||||
|
constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
|
||||||
// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
|
// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
|
||||||
constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP32", "UINT8" or "FP16", default as "FP32"
|
constexpr auto kModelOptionAscend310OutputType =
|
||||||
constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
|
"mindspore.option.ascend310.output_type"; // "FP32", "UINT8" or "FP16", default as "FP32"
|
||||||
|
constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
|
||||||
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
|
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
|
||||||
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
|
constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
|
||||||
constexpr auto KModelOptionFusionSwitchCfgPath = "mindspore.option.fusion_switch_config_file_path";
|
constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path";
|
||||||
// "False": Inference with native backend, "True": Inference with Tensor-RT engine, default as "False"
|
// "False": Inference with native backend, "True": Inference with Tensor-RT engine, default as "False"
|
||||||
constexpr auto kModelOptionGpuTrtInferMode = "mindspore.option.gpu_trt_infer_mode";
|
constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
|
||||||
constexpr auto kModelOptionDynamicBatchSize = "mindspore.option.dynamic_batch_size";
|
|
||||||
constexpr auto kModelOptionDynamicImageSize = "mindspore.option.dynamic_image_size";
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
class Allocator {};
|
||||||
|
|
||||||
struct Context::Data {
|
struct Context::Data {
|
||||||
|
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||||
|
int32_t thread_num;
|
||||||
|
std::shared_ptr<Allocator> allocator;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DeviceInfoContext::Data {
|
||||||
std::map<std::string, std::any> params;
|
std::map<std::string, std::any> params;
|
||||||
};
|
};
|
||||||
|
|
||||||
Context::Context() : data(std::make_shared<Data>()) {}
|
Context::Context() : data_(std::make_shared<Data>()) {}
|
||||||
|
|
||||||
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
|
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
|
||||||
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
|
static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
|
||||||
static U empty_result;
|
static U empty_result;
|
||||||
if (context == nullptr || context->data == nullptr) {
|
if (data == nullptr) {
|
||||||
return empty_result;
|
return empty_result;
|
||||||
}
|
}
|
||||||
auto iter = context->data->params.find(key);
|
auto iter = data->params.find(key);
|
||||||
if (iter == context->data->params.end()) {
|
if (iter == data->params.end()) {
|
||||||
return empty_result;
|
return empty_result;
|
||||||
}
|
}
|
||||||
const std::any &value = iter->second;
|
const std::any &value = iter->second;
|
||||||
|
@ -62,210 +78,205 @@ static const U &GetValue(const std::shared_ptr<Context> &context, const std::str
|
||||||
return std::any_cast<const U &>(value);
|
return std::any_cast<const U &>(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
|
void Context::SetThreadNum(int32_t thread_num) {
|
||||||
static std::shared_ptr<Context> g_context = std::make_shared<Context>();
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
return g_context;
|
data_->thread_num = thread_num;
|
||||||
|
}
|
||||||
|
int32_t Context::GetThreadNum() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return data_->thread_num;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
|
void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
|
||||||
auto global_context = GetGlobalContext();
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
MS_EXCEPTION_IF_NULL(global_context);
|
data_->allocator = allocator;
|
||||||
if (global_context->data == nullptr) {
|
}
|
||||||
global_context->data = std::make_shared<Data>();
|
std::shared_ptr<Allocator> Context::GetAllocator() const {
|
||||||
MS_EXCEPTION_IF_NULL(global_context->data);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
}
|
return data_->allocator;
|
||||||
global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
|
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
||||||
auto global_context = GetGlobalContext();
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
MS_EXCEPTION_IF_NULL(global_context);
|
return data_->device_info_list;
|
||||||
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
|
}
|
||||||
|
|
||||||
|
DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
|
||||||
|
|
||||||
|
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionCpuEnableFP16] = is_fp16;
|
||||||
|
}
|
||||||
|
bool CPUDeviceInfo::GetEnableFP16() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CPUDeviceInfo::SetThreadAffinity(int affinity) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionCpuThreadAffinity] = affinity;
|
||||||
|
}
|
||||||
|
int CPUDeviceInfo::GetThreadAffinity() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<bool>(data_, kModelOptionCpuThreadAffinity);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionMaliGpuEnableFP16] = is_fp16;
|
||||||
|
}
|
||||||
|
bool MaliGPUDeviceInfo::GetEnableFP16() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<bool>(data_, kModelOptionMaliGpuEnableFP16);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KirinNPUDeviceInfo::SetFrequency(int frequency) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionKirinNpuFrequency] = frequency;
|
||||||
|
}
|
||||||
|
int KirinNPUDeviceInfo::GetFrequency() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NvidiaGPUDeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionNvidiaGpuDeviceID] = device_id;
|
||||||
|
}
|
||||||
|
uint32_t NvidiaGPUDeviceInfo::GetDeviceID() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<uint32_t>(data_, kModelOptionNvidiaGpuDeviceID);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NvidiaGPUDeviceInfo::SetGpuTrtInferMode(bool gpu_trt_infer_mode) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionNvidiaGpuTrtInferMode] = gpu_trt_infer_mode;
|
||||||
|
}
|
||||||
|
bool NvidiaGPUDeviceInfo::GetGpuTrtInferMode() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<bool>(data_, kModelOptionNvidiaGpuTrtInferMode);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionAscend910DeviceID] = device_id;
|
||||||
|
}
|
||||||
|
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionAscend310DeviceID] = device_id;
|
||||||
|
}
|
||||||
|
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DeviceID);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
|
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
auto global_context = GetGlobalContext();
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
MS_EXCEPTION_IF_NULL(global_context);
|
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
||||||
if (global_context->data == nullptr) {
|
|
||||||
global_context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(global_context->data);
|
|
||||||
}
|
|
||||||
global_context->data->params[kGlobalContextDeviceID] = device_id;
|
|
||||||
}
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
|
||||||
uint32_t GlobalContext::GetGlobalDeviceID() {
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
auto global_context = GetGlobalContext();
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
|
||||||
MS_EXCEPTION_IF_NULL(global_context);
|
|
||||||
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GlobalContext::SetGlobalDumpConfigPath(const std::vector<char> &cfg_path) {
|
|
||||||
auto global_context = GetGlobalContext();
|
|
||||||
MS_EXCEPTION_IF_NULL(global_context);
|
|
||||||
if (global_context->data == nullptr) {
|
|
||||||
global_context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(global_context->data);
|
|
||||||
}
|
|
||||||
global_context->data->params[kGlobalContextDumpCfgPath] = CharToString(cfg_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> GlobalContext::GetGlobalDumpConfigPathChar() {
|
|
||||||
auto global_context = GetGlobalContext();
|
|
||||||
MS_EXCEPTION_IF_NULL(global_context);
|
|
||||||
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDumpCfgPath);
|
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
|
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
if (context->data == nullptr) {
|
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
|
|
||||||
}
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
|
||||||
std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
|
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
|
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
if (context->data == nullptr) {
|
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionInputFormat] = CharToString(format);
|
|
||||||
}
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
|
||||||
std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
|
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
|
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionInputShape] = CharToString(shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetInputShapeMap(const std::shared_ptr<Context> &context,
|
|
||||||
const std::map<int, std::vector<int>> &shape) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
context->data->params[kModelOptionInputShapeMap] = shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<int, std::vector<int>> ModelContext::GetInputShapeMap(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
return GetValue<std::map<int, std::vector<int>>>(context, kModelOptionInputShapeMap);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionOutputType] = output_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
return GetValue<enum DataType>(context, kModelOptionOutputType);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
|
||||||
const std::vector<char> &op_select_impl_mode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context,
|
|
||||||
const std::vector<char> &cfg_path) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[KModelOptionFusionSwitchCfgPath] = CharToString(cfg_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> ModelContext::GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
const std::string &ref = GetValue<std::string>(context, KModelOptionFusionSwitchCfgPath);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context,
|
|
||||||
const std::vector<char> &gpu_trt_infer_mode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
context->data->params[kModelOptionGpuTrtInferMode] = CharToString(gpu_trt_infer_mode);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<char> ModelContext::GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionGpuTrtInferMode);
|
|
||||||
return StringToChar(ref);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelContext::SetDynamicBatchSize(const std::shared_ptr<Context> &context, const std::vector<size_t> &batch_size) {
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
if (context->data == nullptr) {
|
|
||||||
context->data = std::make_shared<Data>();
|
|
||||||
MS_EXCEPTION_IF_NULL(context->data);
|
|
||||||
}
|
|
||||||
std::string batchs = "";
|
std::string batchs = "";
|
||||||
for (auto bs : batch_size) {
|
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
|
||||||
batchs += std::to_string(bs) + ",";
|
if (i != 0) {
|
||||||
|
batchs.push_back(',');
|
||||||
|
}
|
||||||
|
batchs += std::to_string(dynamic_batch_size[i]);
|
||||||
}
|
}
|
||||||
context->data->params[kModelOptionDynamicBatchSize] = batchs;
|
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
|
||||||
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> ModelContext::GetDynamicBatchSizeChar(const std::shared_ptr<Context> &context) {
|
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
const std::string &ref = GetValue<std::string>(context, kModelOptionDynamicBatchSize);
|
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
|
||||||
return StringToChar(ref);
|
return StringToChar(ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
|
||||||
|
return StringToChar(ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
|
||||||
|
return StringToChar(ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionAscend310InputShapeMap] = shape;
|
||||||
|
}
|
||||||
|
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionAscend310OutputType] = output_type;
|
||||||
|
}
|
||||||
|
enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
inline std::string g_device_target = "Default";
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
class Factory {
|
class Factory {
|
||||||
using U = std::function<std::shared_ptr<T>()>;
|
using U = std::function<std::shared_ptr<T>()>;
|
||||||
|
|
|
@ -45,6 +45,9 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
|
||||||
acl_env = global_acl_env_;
|
acl_env = global_acl_env_;
|
||||||
if (acl_env != nullptr) {
|
if (acl_env != nullptr) {
|
||||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||||
|
if (!cfg_file.empty()) {
|
||||||
|
MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored.";
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
acl_env = std::make_shared<AclEnvGuard>(cfg_file);
|
acl_env = std::make_shared<AclEnvGuard>(cfg_file);
|
||||||
aclError ret = acl_env->GetErrno();
|
aclError ret = acl_env->GetErrno();
|
||||||
|
|
|
@ -25,7 +25,7 @@ AclGraphImpl::AclGraphImpl()
|
||||||
: init_flag_(false),
|
: init_flag_(false),
|
||||||
load_flag_(false),
|
load_flag_(false),
|
||||||
device_type_("AscendCL"),
|
device_type_("AscendCL"),
|
||||||
device_id_(GlobalContext::GetGlobalDeviceID()),
|
device_id_(0),
|
||||||
context_(nullptr),
|
context_(nullptr),
|
||||||
acl_env_(nullptr) {}
|
acl_env_(nullptr) {}
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }
|
||||||
|
|
||||||
Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(outputs);
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -43,7 +43,7 @@ Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTens
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MSTensor> AclGraphImpl::GetInputs() {
|
std::vector<MSTensor> AclGraphImpl::GetInputs() {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||||
return {};
|
return {};
|
||||||
|
@ -53,7 +53,7 @@ std::vector<MSTensor> AclGraphImpl::GetInputs() {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MSTensor> AclGraphImpl::GetOutputs() {
|
std::vector<MSTensor> AclGraphImpl::GetOutputs() {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||||
return {};
|
return {};
|
||||||
|
@ -90,7 +90,7 @@ Status AclGraphImpl::InitEnv() {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
acl_env_ = AclEnvGuard::GetAclEnv(GlobalContext::GetGlobalDumpConfigPath());
|
acl_env_ = AclEnvGuard::GetAclEnv("");
|
||||||
if (acl_env_ == nullptr) {
|
if (acl_env_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Acl init failed.";
|
MS_LOG(ERROR) << "Acl init failed.";
|
||||||
return kMCDeviceError;
|
return kMCDeviceError;
|
||||||
|
@ -161,7 +161,7 @@ Status AclGraphImpl::FinalizeEnv() {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AclGraphImpl::Load() {
|
Status AclGraphImpl::Load(uint32_t device_id) {
|
||||||
// check graph type
|
// check graph type
|
||||||
if (graph_->ModelType() != ModelType::kOM) {
|
if (graph_->ModelType() != ModelType::kOM) {
|
||||||
Status ret = ConvertToOM();
|
Status ret = ConvertToOM();
|
||||||
|
@ -176,6 +176,7 @@ Status AclGraphImpl::Load() {
|
||||||
auto om_data = graph_data->GetOMData();
|
auto om_data = graph_data->GetOMData();
|
||||||
|
|
||||||
// init
|
// init
|
||||||
|
device_id_ = device_id;
|
||||||
Status ret = InitEnv();
|
Status ret = InitEnv();
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "InitEnv failed.";
|
MS_LOG(ERROR) << "InitEnv failed.";
|
||||||
|
|
|
@ -34,7 +34,7 @@ class AclGraphImpl : public GraphCell::GraphImpl {
|
||||||
~AclGraphImpl() override;
|
~AclGraphImpl() override;
|
||||||
|
|
||||||
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||||
Status Load() override;
|
Status Load(uint32_t device_id) override;
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ AscendGraphImpl::AscendGraphImpl()
|
||||||
: session_impl_(nullptr),
|
: session_impl_(nullptr),
|
||||||
graph_id_(0),
|
graph_id_(0),
|
||||||
device_type_("Ascend"),
|
device_type_("Ascend"),
|
||||||
device_id_(GlobalContext::GetGlobalDeviceID()),
|
device_id_(0),
|
||||||
context_(nullptr),
|
context_(nullptr),
|
||||||
inputs_info_(),
|
inputs_info_(),
|
||||||
outputs_info_(),
|
outputs_info_(),
|
||||||
|
@ -142,7 +142,7 @@ Status AscendGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::
|
||||||
|
|
||||||
std::vector<MSTensor> AscendGraphImpl::GetInputs() {
|
std::vector<MSTensor> AscendGraphImpl::GetInputs() {
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||||
return {};
|
return {};
|
||||||
|
@ -166,7 +166,7 @@ std::vector<MSTensor> AscendGraphImpl::GetInputs() {
|
||||||
|
|
||||||
std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
|
std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||||
return {};
|
return {};
|
||||||
|
@ -188,7 +188,7 @@ std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AscendGraphImpl::Load() {
|
Status AscendGraphImpl::Load(uint32_t device_id) {
|
||||||
// check graph type
|
// check graph type
|
||||||
if (graph_->ModelType() != ModelType::kMindIR) {
|
if (graph_->ModelType() != ModelType::kMindIR) {
|
||||||
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
||||||
|
@ -200,6 +200,7 @@ Status AscendGraphImpl::Load() {
|
||||||
auto func_graph = graph_data->GetFuncGraph();
|
auto func_graph = graph_data->GetFuncGraph();
|
||||||
|
|
||||||
// init
|
// init
|
||||||
|
device_id_ = device_id;
|
||||||
Status ret = InitEnv();
|
Status ret = InitEnv();
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "InitEnv failed.";
|
MS_LOG(ERROR) << "InitEnv failed.";
|
||||||
|
@ -247,7 +248,7 @@ Status AscendGraphImpl::Load() {
|
||||||
Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(outputs);
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
|
|
@ -36,7 +36,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl {
|
||||||
~AscendGraphImpl() override;
|
~AscendGraphImpl() override;
|
||||||
|
|
||||||
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||||
Status Load() override;
|
Status Load(uint32_t device_id) override;
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
|
||||||
GPUGraphImpl::GPUGraphImpl()
|
GPUGraphImpl::GPUGraphImpl()
|
||||||
: session_impl_(nullptr),
|
: session_impl_(nullptr),
|
||||||
graph_id_(0),
|
graph_id_(0),
|
||||||
device_id_(GlobalContext::GetGlobalDeviceID()),
|
device_id_(0),
|
||||||
inputs_info_(),
|
inputs_info_(),
|
||||||
outputs_info_(),
|
outputs_info_(),
|
||||||
input_names_(),
|
input_names_(),
|
||||||
|
@ -83,7 +83,7 @@ Status GPUGraphImpl::FinalizeEnv() {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GPUGraphImpl::Load() {
|
Status GPUGraphImpl::Load(uint32_t device_id) {
|
||||||
// check graph type
|
// check graph type
|
||||||
if (graph_->ModelType() != ModelType::kMindIR) {
|
if (graph_->ModelType() != ModelType::kMindIR) {
|
||||||
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
|
||||||
|
@ -95,6 +95,7 @@ Status GPUGraphImpl::Load() {
|
||||||
auto func_graph = graph_data->GetFuncGraph();
|
auto func_graph = graph_data->GetFuncGraph();
|
||||||
|
|
||||||
// init
|
// init
|
||||||
|
device_id_ = device_id;
|
||||||
Status ret = InitEnv();
|
Status ret = InitEnv();
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "InitEnv failed.";
|
MS_LOG(ERROR) << "InitEnv failed.";
|
||||||
|
@ -176,7 +177,7 @@ Status GPUGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vec
|
||||||
Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(outputs);
|
MS_EXCEPTION_IF_NULL(outputs);
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -211,7 +212,7 @@ Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTens
|
||||||
|
|
||||||
std::vector<MSTensor> GPUGraphImpl::GetInputs() {
|
std::vector<MSTensor> GPUGraphImpl::GetInputs() {
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||||
return {};
|
return {};
|
||||||
|
@ -235,7 +236,7 @@ std::vector<MSTensor> GPUGraphImpl::GetInputs() {
|
||||||
|
|
||||||
std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
|
std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
|
||||||
if (!load_flag_) {
|
if (!load_flag_) {
|
||||||
Status ret = Load();
|
Status ret = Load(device_id_);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "PrepareModel failed.";
|
MS_LOG(ERROR) << "PrepareModel failed.";
|
||||||
return {};
|
return {};
|
||||||
|
|
|
@ -33,7 +33,7 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
|
||||||
~GPUGraphImpl() override = default;
|
~GPUGraphImpl() override = default;
|
||||||
|
|
||||||
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||||
Status Load() override;
|
Status Load(uint32_t device_id) override;
|
||||||
std::vector<MSTensor> GetInputs() override;
|
std::vector<MSTensor> GetInputs() override;
|
||||||
std::vector<MSTensor> GetOutputs() override;
|
std::vector<MSTensor> GetOutputs() override;
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
Graph::Graph() : graph_data_(nullptr) {}
|
||||||
|
|
||||||
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
|
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
|
||||||
|
|
||||||
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
||||||
|
@ -28,6 +30,8 @@ Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
|
||||||
|
|
||||||
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
|
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
|
||||||
|
|
||||||
|
bool Graph::operator!=(std::nullptr_t) const { return graph_data_ != nullptr; }
|
||||||
|
|
||||||
ModelType Graph::ModelType() const {
|
ModelType Graph::ModelType() const {
|
||||||
MS_EXCEPTION_IF_NULL(graph_data_);
|
MS_EXCEPTION_IF_NULL(graph_data_);
|
||||||
return graph_data_->ModelType();
|
return graph_data_->ModelType();
|
||||||
|
|
|
@ -36,7 +36,7 @@ class GraphCell::GraphImpl {
|
||||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||||
|
|
||||||
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
|
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
|
||||||
virtual Status Load() = 0;
|
virtual Status Load(uint32_t device_id) = 0;
|
||||||
|
|
||||||
virtual std::vector<MSTensor> GetInputs() = 0;
|
virtual std::vector<MSTensor> GetInputs() = 0;
|
||||||
virtual std::vector<MSTensor> GetOutputs() = 0;
|
virtual std::vector<MSTensor> GetOutputs() = 0;
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
#include "cxx_api/factory.h"
|
#include "cxx_api/factory.h"
|
||||||
|
#include "cxx_api/graph/acl/acl_env_guard.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
|
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
|
||||||
|
@ -40,6 +41,11 @@ Status AclModel::Build() {
|
||||||
|
|
||||||
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
|
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
|
||||||
MS_EXCEPTION_IF_NULL(options);
|
MS_EXCEPTION_IF_NULL(options);
|
||||||
|
std::string dump_cfg = options->GetDumpCfgPath();
|
||||||
|
if (!dump_cfg.empty()) {
|
||||||
|
MS_LOG(INFO) << "Options dump config file path " << dump_cfg;
|
||||||
|
(void)AclEnvGuard::GetAclEnv(dump_cfg);
|
||||||
|
}
|
||||||
std::string options_key = options->GenAclOptionsKey();
|
std::string options_key = options->GenAclOptionsKey();
|
||||||
std::shared_ptr<Graph> graph;
|
std::shared_ptr<Graph> graph;
|
||||||
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
|
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
|
||||||
|
@ -75,7 +81,7 @@ Status AclModel::Build() {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||||
auto ret = ModelImpl::Load(graph_cell);
|
auto ret = ModelImpl::Load(graph_cell, options->GetDeviceID());
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Load failed.";
|
MS_LOG(ERROR) << "Load failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -108,7 +114,8 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_context_ == nullptr) {
|
if (model_context_ == nullptr) {
|
||||||
model_context_ = std::make_shared<ModelContext>();
|
model_context_ = std::make_shared<Context>();
|
||||||
|
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string input_shape_option;
|
std::string input_shape_option;
|
||||||
|
@ -130,7 +137,14 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Set input size option is " << input_shape_option;
|
MS_LOG(INFO) << "Set input size option is " << input_shape_option;
|
||||||
ModelContext::SetInputShape(model_context_, input_shape_option);
|
auto &device_infos = model_context_->MutableDeviceInfo();
|
||||||
|
if (device_infos.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
|
||||||
|
return kMCInvalidArgs;
|
||||||
|
}
|
||||||
|
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
|
||||||
|
MS_EXCEPTION_IF_NULL(ascend310_info);
|
||||||
|
ascend310_info->SetInputShape(input_shape_option);
|
||||||
auto graph_cell_bak = std::move(graph_cell_);
|
auto graph_cell_bak = std::move(graph_cell_);
|
||||||
auto ret = Build();
|
auto ret = Build();
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
|
|
|
@ -27,10 +27,19 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
|
||||||
if (context == nullptr) {
|
if (context == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
insert_op_cfg_path_ = ModelContext::GetInsertOpConfigPath(context);
|
auto &device_infos = context->MutableDeviceInfo();
|
||||||
input_format_ = ModelContext::GetInputFormat(context);
|
if (device_infos.size() != 1) {
|
||||||
input_shape_map_ = ModelContext::GetInputShapeMap(context);
|
return;
|
||||||
auto out_type = ModelContext::GetOutputType(context);
|
}
|
||||||
|
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
|
||||||
|
if (ascend310_info == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
insert_op_cfg_path_ = ascend310_info->GetInsertOpConfigPath();
|
||||||
|
input_format_ = ascend310_info->GetInputFormat();
|
||||||
|
input_shape_map_ = ascend310_info->GetInputShapeMap();
|
||||||
|
auto out_type = ascend310_info->GetOutputType();
|
||||||
auto iter = kSupportedDtypeOptionMap.find(out_type);
|
auto iter = kSupportedDtypeOptionMap.find(out_type);
|
||||||
if (out_type == DataType::kTypeUnknown) {
|
if (out_type == DataType::kTypeUnknown) {
|
||||||
// do nothing
|
// do nothing
|
||||||
|
@ -39,10 +48,12 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
|
||||||
} else {
|
} else {
|
||||||
output_type_ = iter->second;
|
output_type_ = iter->second;
|
||||||
}
|
}
|
||||||
dynamic_batch_size_ = ModelContext::GetDynamicBatchSize(context);
|
dynamic_batch_size_ = ascend310_info->GetDynamicBatchSize();
|
||||||
precision_mode_ = ModelContext::GetPrecisionMode(context);
|
precision_mode_ = ascend310_info->GetPrecisionMode();
|
||||||
op_select_impl_mode_ = ModelContext::GetOpSelectImplMode(context);
|
op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
|
||||||
fusion_switch_cfg_path_ = ModelContext::GetFusionSwitchConfigPath(context);
|
fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
|
||||||
|
device_id_ = ascend310_info->GetDeviceID();
|
||||||
|
dump_cfg_path_ = ascend310_info->GetDumpConfigPath();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AclModelOptions::RenameInput(const std::vector<std::string> &input_names) {
|
void AclModelOptions::RenameInput(const std::vector<std::string> &input_names) {
|
||||||
|
|
|
@ -31,6 +31,8 @@ class AclModelOptions {
|
||||||
explicit AclModelOptions(const std::shared_ptr<Context> &context);
|
explicit AclModelOptions(const std::shared_ptr<Context> &context);
|
||||||
~AclModelOptions() = default;
|
~AclModelOptions() = default;
|
||||||
std::string GenAclOptionsKey() const;
|
std::string GenAclOptionsKey() const;
|
||||||
|
uint32_t GetDeviceID() const { return device_id_; }
|
||||||
|
std::string GetDumpCfgPath() const { return dump_cfg_path_; }
|
||||||
void RenameInput(const std::vector<std::string> &);
|
void RenameInput(const std::vector<std::string> &);
|
||||||
|
|
||||||
// return tuple<init_options, build_options>
|
// return tuple<init_options, build_options>
|
||||||
|
@ -50,7 +52,9 @@ class AclModelOptions {
|
||||||
std::string dynamic_batch_size_;
|
std::string dynamic_batch_size_;
|
||||||
std::string dynamic_image_size_;
|
std::string dynamic_image_size_;
|
||||||
std::map<int, std::vector<int>> input_shape_map_;
|
std::map<int, std::vector<int>> input_shape_map_;
|
||||||
std::vector<std::string> dynamic_image_size_nums_;
|
// other options
|
||||||
|
uint32_t device_id_;
|
||||||
|
std::string dump_cfg_path_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -21,60 +21,130 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace {
|
namespace {
|
||||||
const std::map<std::string, std::set<ModelType>> kSupportedModelMap = {
|
const std::map<enum DeviceType, std::set<ModelType>> kSupportedModelMap = {
|
||||||
{kDeviceTypeAscend310, {kOM, kMindIR}},
|
{kAscend310, {kOM, kMindIR}},
|
||||||
{kDeviceTypeAscend910, {kMindIR}},
|
{kAscend910, {kMindIR}},
|
||||||
{kDeviceTypeGPU, {kMindIR}},
|
{kNvidiaGPU, {kMindIR}},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::string GetDeviceTypeString(enum DeviceType type) {
|
||||||
|
static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
|
||||||
|
{kCPU, "CPU"}, {kMaliGPU, "MaliGPU"}, {kNvidiaGPU, "GPU"},
|
||||||
|
{kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
|
||||||
|
};
|
||||||
|
auto iter = kDeviceTypeStrs.find(type);
|
||||||
|
if (iter != kDeviceTypeStrs.end()) {
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
return "InvalidDeviceType" + std::to_string(type);
|
||||||
}
|
}
|
||||||
Status Model::Build() {
|
} // namespace
|
||||||
MS_EXCEPTION_IF_NULL(impl_);
|
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context) {
|
||||||
|
if (graph_cell.GetGraph() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid graph input.";
|
||||||
|
return kMCInvalidInput;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model_context == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid model context.";
|
||||||
|
return kMCInvalidInput;
|
||||||
|
}
|
||||||
|
auto &device_info = model_context->MutableDeviceInfo();
|
||||||
|
if (device_info.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
|
||||||
|
return kMCInvalidInput;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
|
||||||
|
impl_ = Factory<ModelImpl>::Instance().Create(device_target);
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create session type " << device_target << " failed";
|
||||||
|
return kMEFailed;
|
||||||
|
}
|
||||||
|
|
||||||
|
g_device_target = device_target;
|
||||||
|
|
||||||
|
impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
|
||||||
|
impl_->SetContext(model_context);
|
||||||
|
|
||||||
return impl_->Build();
|
return impl_->Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
|
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
|
||||||
MS_EXCEPTION_IF_NULL(impl_);
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed because this model has not been built.";
|
||||||
|
return kMCFailed;
|
||||||
|
}
|
||||||
return impl_->Resize(inputs, dims);
|
return impl_->Resize(inputs, dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
MS_EXCEPTION_IF_NULL(impl_);
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed because this model has not been built.";
|
||||||
|
return kMCFailed;
|
||||||
|
}
|
||||||
return impl_->Predict(inputs, outputs);
|
return impl_->Predict(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MSTensor> Model::GetInputs() {
|
std::vector<MSTensor> Model::GetInputs() {
|
||||||
MS_EXCEPTION_IF_NULL(impl_);
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed because this model has not been built.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
return impl_->GetInputs();
|
return impl_->GetInputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MSTensor> Model::GetOutputs() {
|
std::vector<MSTensor> Model::GetOutputs() {
|
||||||
MS_EXCEPTION_IF_NULL(impl_);
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed because this model has not been built.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
return impl_->GetOutputs();
|
return impl_->GetOutputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
Model::Model(const GraphCell &graph_cell, const std::shared_ptr<Context> &model_context)
|
MSTensor Model::GetInputByTensorName(const std::vector<char> &tensor_name) {
|
||||||
: impl_(Factory<ModelImpl>::Instance().Create(mindspore::GlobalContext::GetGlobalDeviceTarget())) {
|
std::string tensor_name_str = CharToString(tensor_name);
|
||||||
if (impl_ == nullptr) {
|
auto inputs = GetInputs();
|
||||||
MS_LOG(EXCEPTION) << "Create session type " << mindspore::GlobalContext::GetGlobalDeviceTarget() << " failed";
|
for (auto in : inputs) {
|
||||||
|
if (in.Name() == tensor_name_str) {
|
||||||
|
return in;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell.GetGraph());
|
|
||||||
impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
|
return MSTensor(std::shared_ptr<MSTensor::Impl>(nullptr));
|
||||||
impl_->SetContext(model_context);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context) {
|
std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
|
||||||
MS_LOG(EXCEPTION) << "Unsupported feature.";
|
std::vector<std::vector<char>> ret;
|
||||||
|
auto outputs = GetOutputs();
|
||||||
|
std::transform(outputs.begin(), outputs.end(), std::back_inserter(ret),
|
||||||
|
[](MSTensor item) -> std::vector<char> { return StringToChar(item.Name()); });
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
|
||||||
|
std::string tensor_name_str = CharToString(tensor_name);
|
||||||
|
auto outputs = GetOutputs();
|
||||||
|
for (auto out : outputs) {
|
||||||
|
if (out.Name() == tensor_name_str) {
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return MSTensor(std::shared_ptr<MSTensor::Impl>(nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
Model::Model() : impl_(nullptr) {}
|
||||||
Model::~Model() {}
|
Model::~Model() {}
|
||||||
|
|
||||||
bool Model::CheckModelSupport(const std::vector<char> &device_type, ModelType model_type) {
|
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
|
||||||
std::string device_type_str = CharToString(device_type);
|
std::string device_type_str = GetDeviceTypeString(device_type);
|
||||||
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
|
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto first_iter = kSupportedModelMap.find(device_type_str);
|
auto first_iter = kSupportedModelMap.find(device_type);
|
||||||
if (first_iter == kSupportedModelMap.end()) {
|
if (first_iter == kSupportedModelMap.end()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,9 +43,9 @@ class ModelImpl {
|
||||||
virtual std::vector<MSTensor> GetOutputs() = 0;
|
virtual std::vector<MSTensor> GetOutputs() = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status Load(const std::shared_ptr<GraphCell> &graph_cell) {
|
Status Load(const std::shared_ptr<GraphCell> &graph_cell, uint32_t device_id) {
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||||
return graph_cell->Load();
|
return graph_cell->Load(device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr GetFuncGraph() const {
|
FuncGraphPtr GetFuncGraph() const {
|
||||||
|
|
|
@ -74,7 +74,7 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||||
auto ret = ModelImpl::Load(graph_cell);
|
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Load failed.";
|
MS_LOG(ERROR) << "Load failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -99,7 +99,7 @@ Status MsModel::Build() {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto graph_cell = std::make_shared<GraphCell>(graph);
|
auto graph_cell = std::make_shared<GraphCell>(graph);
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell);
|
MS_EXCEPTION_IF_NULL(graph_cell);
|
||||||
auto ret = ModelImpl::Load(graph_cell);
|
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Load failed.";
|
MS_LOG(ERROR) << "Load failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -170,4 +170,27 @@ std::vector<MSTensor> MsModel::GetOutputs() {
|
||||||
MS_EXCEPTION_IF_NULL(graph_cell_);
|
MS_EXCEPTION_IF_NULL(graph_cell_);
|
||||||
return graph_cell_->GetOutputs();
|
return graph_cell_->GetOutputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t MsModel::GetDeviceID() const {
|
||||||
|
if (model_context_ == nullptr) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &device_infos = model_context_->MutableDeviceInfo();
|
||||||
|
if (device_infos.size() != 1) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ascend910_info = device_infos[0]->Cast<Ascend910DeviceInfo>();
|
||||||
|
if (ascend910_info != nullptr) {
|
||||||
|
return ascend910_info->GetDeviceID();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gpu_info = device_infos[0]->Cast<NvidiaGPUDeviceInfo>();
|
||||||
|
if (gpu_info != nullptr) {
|
||||||
|
return gpu_info->GetDeviceID();
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -48,6 +48,7 @@ class MsModel : public ModelImpl {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);
|
std::shared_ptr<GraphCell> GenerateGraphCell(const std::vector<std::vector<int64_t>> &dims);
|
||||||
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
std::shared_ptr<GraphCell> graph_cell_;
|
std::shared_ptr<GraphCell> graph_cell_;
|
||||||
std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_;
|
std::map<std::string, std::shared_ptr<GraphCell>> dynamic_size_graph_map_;
|
||||||
|
|
|
@ -68,38 +68,59 @@ static Buffer ReadFile(const std::string &file) {
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelType model_type) {
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
||||||
|
if (graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||||
|
return kMEInvalidInput;
|
||||||
|
}
|
||||||
|
|
||||||
if (model_type == kMindIR) {
|
if (model_type == kMindIR) {
|
||||||
FuncGraphPtr anf_graph = nullptr;
|
FuncGraphPtr anf_graph = nullptr;
|
||||||
try {
|
try {
|
||||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
|
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
|
||||||
} catch (const std::exception &) {
|
} catch (const std::exception &) {
|
||||||
MS_LOG(EXCEPTION) << "Load MindIR failed.";
|
MS_LOG(ERROR) << "Load model failed.";
|
||||||
|
return kMEInvalidInput;
|
||||||
}
|
}
|
||||||
|
|
||||||
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||||
|
return kSuccess;
|
||||||
} else if (model_type == kOM) {
|
} else if (model_type == kOM) {
|
||||||
return Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
|
*graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
|
||||||
|
return kSuccess;
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
|
|
||||||
|
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||||
|
return kMEInvalidInput;
|
||||||
}
|
}
|
||||||
|
|
||||||
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||||
|
if (graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||||
|
return kMEInvalidInput;
|
||||||
|
}
|
||||||
|
|
||||||
std::string file_path = CharToString(file);
|
std::string file_path = CharToString(file);
|
||||||
if (model_type == kMindIR) {
|
if (model_type == kMindIR) {
|
||||||
FuncGraphPtr anf_graph = LoadMindIR(file_path);
|
FuncGraphPtr anf_graph = LoadMindIR(file_path);
|
||||||
if (anf_graph == nullptr) {
|
if (anf_graph == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Load model failed.";
|
MS_LOG(ERROR) << "Load model failed.";
|
||||||
|
return kMEInvalidInput;
|
||||||
}
|
}
|
||||||
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||||
|
return kSuccess;
|
||||||
} else if (model_type == kOM) {
|
} else if (model_type == kOM) {
|
||||||
Buffer data = ReadFile(file_path);
|
Buffer data = ReadFile(file_path);
|
||||||
if (data.Data() == nullptr) {
|
if (data.Data() == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Read file " << file_path << " failed.";
|
MS_LOG(ERROR) << "Read file " << file_path << " failed.";
|
||||||
|
return kMEInvalidInput;
|
||||||
}
|
}
|
||||||
return Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||||
|
return kSuccess;
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
|
|
||||||
|
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||||
|
return kMEInvalidInput;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
||||||
|
|
|
@ -134,33 +134,139 @@ class TensorReferenceImpl : public MSTensor::Impl {
|
||||||
std::vector<int64_t> shape_;
|
std::vector<int64_t> shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept {
|
const void *data, size_t data_len) noexcept {
|
||||||
std::string name_str = CharToString(name);
|
std::string name_str = CharToString(name);
|
||||||
try {
|
try {
|
||||||
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
|
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
|
||||||
return MSTensor(impl);
|
MSTensor *ret = new MSTensor(impl);
|
||||||
|
return ret;
|
||||||
} catch (const std::bad_alloc &) {
|
} catch (const std::bad_alloc &) {
|
||||||
MS_LOG(ERROR) << "Malloc memory failed.";
|
MS_LOG(ERROR) << "Malloc memory failed.";
|
||||||
return MSTensor(nullptr);
|
return nullptr;
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
MS_LOG(ERROR) << "Unknown error occurred.";
|
MS_LOG(ERROR) << "Unknown error occurred.";
|
||||||
return MSTensor(nullptr);
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
MSTensor *MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type,
|
||||||
const void *data, size_t data_len) noexcept {
|
const std::vector<int64_t> &shape, const void *data, size_t data_len) noexcept {
|
||||||
std::string name_str = CharToString(name);
|
std::string name_str = CharToString(name);
|
||||||
try {
|
try {
|
||||||
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
|
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
|
||||||
return MSTensor(impl);
|
MSTensor *ret = new MSTensor(impl);
|
||||||
|
return ret;
|
||||||
} catch (const std::bad_alloc &) {
|
} catch (const std::bad_alloc &) {
|
||||||
MS_LOG(ERROR) << "Malloc memory failed.";
|
MS_LOG(ERROR) << "Malloc memory failed.";
|
||||||
return MSTensor(nullptr);
|
return nullptr;
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
MS_LOG(ERROR) << "Unknown error occurred.";
|
MS_LOG(ERROR) << "Unknown error occurred.";
|
||||||
return MSTensor(nullptr);
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MSTensor *MSTensor::CharStringsToTensor(const std::vector<char> &name, const std::vector<std::vector<char>> &str) {
|
||||||
|
// num(4 bytes) + offset1(4 bytes) + offset2(4 bytes) + ... + data1(str1.len) + data2(str2.len) + ...
|
||||||
|
// str1.len() = offset2 - offset1
|
||||||
|
// data1.begin() = start + offset1
|
||||||
|
size_t mem_size = 0;
|
||||||
|
mem_size += sizeof(int32_t); // for num
|
||||||
|
for (const auto &s : str) {
|
||||||
|
mem_size += sizeof(int32_t); // for offset
|
||||||
|
mem_size += s.size(); // for data
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tensor = CreateTensor(name, DataType::kObjectTypeString, {static_cast<int64_t>(mem_size)}, nullptr, mem_size);
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t *data = reinterpret_cast<int32_t *>(tensor->MutableData());
|
||||||
|
if (data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
|
DestroyTensorPtr(tensor);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
uint8_t *cur_data = reinterpret_cast<uint8_t *>(data + 1 + str.size());
|
||||||
|
*reinterpret_cast<int32_t *>(data) = str.size();
|
||||||
|
for (size_t i = 0; i < str.size(); ++i) {
|
||||||
|
int32_t offset = (cur_data - reinterpret_cast<uint8_t *>(data));
|
||||||
|
data[i + 1] = offset;
|
||||||
|
if (str[i].empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto ret = memcpy_s(reinterpret_cast<void *>(cur_data), str[i].size(), str[i].data(), str[i].size());
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret;
|
||||||
|
DestroyTensorPtr(tensor);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
cur_data += str[i].size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<char>> MSTensor::TensorToStringChars(const MSTensor &tensor) {
|
||||||
|
if (tensor == nullptr || tensor.DataType() != DataType::kObjectTypeString || tensor.DataSize() < 4) {
|
||||||
|
MS_LOG(ERROR) << "Invalid tensor.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<char>> strings;
|
||||||
|
auto host_data = tensor.Data();
|
||||||
|
const int32_t *data = reinterpret_cast<const int32_t *>(host_data.get());
|
||||||
|
int32_t str_num = data[0];
|
||||||
|
if (str_num == 0) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
if (str_num < 0) {
|
||||||
|
MS_LOG(ERROR) << "str num " << str_num << " cannot be negative.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tensor.DataSize() < (str_num + 1) * sizeof(int32_t)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid tensor data size " << tensor.DataSize() << ", need " << (str_num + 1) * sizeof(int32_t)
|
||||||
|
<< " at least for " << str_num << " strings.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < static_cast<size_t>(str_num); ++i) {
|
||||||
|
strings.push_back({});
|
||||||
|
auto &str = strings[i];
|
||||||
|
int32_t str_len;
|
||||||
|
int32_t offset = data[i + 1];
|
||||||
|
if (i + 1 != static_cast<size_t>(str_num)) {
|
||||||
|
str_len = data[i + 1 + 1] - offset;
|
||||||
|
} else {
|
||||||
|
str_len = tensor.DataSize() - offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (str_len == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (str_len < 0) {
|
||||||
|
MS_LOG(ERROR) << "str " << i << " len " << str_len << " cannot be negative.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
str.resize(str_len);
|
||||||
|
const uint8_t *cur_data = reinterpret_cast<const uint8_t *>(data) + offset;
|
||||||
|
auto ret = memcpy_s(reinterpret_cast<void *>(str.data()), str.size(), cur_data, str_len);
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret;
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MSTensor::DestroyTensorPtr(MSTensor *tensor) noexcept {
|
||||||
|
if (tensor != nullptr) {
|
||||||
|
delete tensor;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,11 +280,21 @@ MSTensor::~MSTensor() = default;
|
||||||
|
|
||||||
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
|
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
|
||||||
|
|
||||||
MSTensor MSTensor::Clone() const {
|
bool MSTensor::operator!=(std::nullptr_t) const { return impl_ != nullptr; }
|
||||||
|
|
||||||
|
MSTensor *MSTensor::Clone() const {
|
||||||
MS_EXCEPTION_IF_NULL(impl_);
|
MS_EXCEPTION_IF_NULL(impl_);
|
||||||
MSTensor ret;
|
try {
|
||||||
ret.impl_ = impl_->Clone();
|
MSTensor *ret = new MSTensor();
|
||||||
return ret;
|
ret->impl_ = impl_->Clone();
|
||||||
|
return ret;
|
||||||
|
} catch (const std::bad_alloc &) {
|
||||||
|
MS_LOG(ERROR) << "Malloc memory failed.";
|
||||||
|
return nullptr;
|
||||||
|
} catch (...) {
|
||||||
|
MS_LOG(ERROR) << "Unknown error occurred.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> MSTensor::CharName() const {
|
std::vector<char> MSTensor::CharName() const {
|
||||||
|
|
|
@ -55,14 +55,14 @@ struct Execute::ExtraInfo {
|
||||||
};
|
};
|
||||||
|
|
||||||
// FIXME - Temporarily overload Execute to support both TensorOperation and TensorTransform
|
// FIXME - Temporarily overload Execute to support both TensorOperation and TensorTransform
|
||||||
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType) {
|
Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType, uint32_t device_id) {
|
||||||
ops_.emplace_back(std::move(op));
|
ops_.emplace_back(std::move(op));
|
||||||
device_type_ = deviceType;
|
device_type_ = deviceType;
|
||||||
info_ = std::make_shared<ExtraInfo>();
|
info_ = std::make_shared<ExtraInfo>();
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -71,7 +71,7 @@ Execute::Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType) {
|
Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType, uint32_t device_id) {
|
||||||
// Initialize the op and other context
|
// Initialize the op and other context
|
||||||
transforms_.emplace_back(op);
|
transforms_.emplace_back(op);
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -89,7 +89,7 @@ Execute::Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice deviceType) {
|
Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice deviceType, uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
std::shared_ptr<TensorOperation> operation = op.get().Parse();
|
std::shared_ptr<TensorOperation> operation = op.get().Parse();
|
||||||
ops_.emplace_back(std::move(operation));
|
ops_.emplace_back(std::move(operation));
|
||||||
|
@ -100,7 +100,7 @@ Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice dev
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -110,7 +110,7 @@ Execute::Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice dev
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute function for the example case: auto decode(new vision::Decode());
|
// Execute function for the example case: auto decode(new vision::Decode());
|
||||||
Execute::Execute(TensorTransform *op, MapTargetDevice deviceType) {
|
Execute::Execute(TensorTransform *op, MapTargetDevice deviceType, uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
std::shared_ptr<TensorTransform> smart_ptr_op(op);
|
std::shared_ptr<TensorTransform> smart_ptr_op(op);
|
||||||
transforms_.emplace_back(smart_ptr_op);
|
transforms_.emplace_back(smart_ptr_op);
|
||||||
|
@ -120,7 +120,7 @@ Execute::Execute(TensorTransform *op, MapTargetDevice deviceType) {
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -129,13 +129,13 @@ Execute::Execute(TensorTransform *op, MapTargetDevice deviceType) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice deviceType)
|
Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDevice deviceType, uint32_t device_id)
|
||||||
: ops_(std::move(ops)), device_type_(deviceType) {
|
: ops_(std::move(ops)), device_type_(deviceType) {
|
||||||
info_ = std::make_shared<ExtraInfo>();
|
info_ = std::make_shared<ExtraInfo>();
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -144,7 +144,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops, MapTargetDev
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice deviceType) {
|
Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDevice deviceType, uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
transforms_ = ops;
|
transforms_ = ops;
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -162,7 +162,8 @@ Execute::Execute(std::vector<std::shared_ptr<TensorTransform>> ops, MapTargetDev
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice deviceType) {
|
Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops, MapTargetDevice deviceType,
|
||||||
|
uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
if (deviceType == MapTargetDevice::kCpu) {
|
if (deviceType == MapTargetDevice::kCpu) {
|
||||||
(void)std::transform(
|
(void)std::transform(
|
||||||
|
@ -180,7 +181,7 @@ Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
@ -190,7 +191,7 @@ Execute::Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute function for the example vector case: auto decode(new vision::Decode());
|
// Execute function for the example vector case: auto decode(new vision::Decode());
|
||||||
Execute::Execute(std::vector<TensorTransform *> ops, MapTargetDevice deviceType) {
|
Execute::Execute(std::vector<TensorTransform *> ops, MapTargetDevice deviceType, uint32_t device_id) {
|
||||||
// Initialize the transforms_ and other context
|
// Initialize the transforms_ and other context
|
||||||
for (auto &op : ops) {
|
for (auto &op : ops) {
|
||||||
std::shared_ptr<TensorTransform> smart_ptr_op(op);
|
std::shared_ptr<TensorTransform> smart_ptr_op(op);
|
||||||
|
@ -202,7 +203,7 @@ Execute::Execute(std::vector<TensorTransform *> ops, MapTargetDevice deviceType)
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
if (device_type_ == MapTargetDevice::kAscend310) {
|
if (device_type_ == MapTargetDevice::kAscend310) {
|
||||||
device_resource_ = std::make_shared<AscendResource>();
|
device_resource_ = std::make_shared<AscendResource>();
|
||||||
Status rc = device_resource_->InitResource();
|
Status rc = device_resource_->InitResource(device_id);
|
||||||
if (!rc.IsOk()) {
|
if (!rc.IsOk()) {
|
||||||
device_resource_ = nullptr;
|
device_resource_ = nullptr;
|
||||||
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
MS_LOG(ERROR) << "Initialize Ascend310 resource fail";
|
||||||
|
|
|
@ -23,10 +23,10 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
Status AscendResource::InitResource() {
|
Status AscendResource::InitResource(uint32_t device_id) {
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(device_id);
|
||||||
ascend_resource_ = ResourceManager::GetInstance();
|
ascend_resource_ = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = ascend_resource_->InitResource(resource);
|
APP_ERROR ret = ascend_resource_->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
@ -35,8 +35,8 @@ Status AscendResource::InitResource() {
|
||||||
MS_LOG(ERROR) << err_msg;
|
MS_LOG(ERROR) << err_msg;
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
}
|
}
|
||||||
int device_id = *(resource.deviceIds.begin());
|
int cur_device_id = *(resource.deviceIds.begin());
|
||||||
aclrtContext context = ascend_resource_->GetContext(device_id);
|
aclrtContext context = ascend_resource_->GetContext(cur_device_id);
|
||||||
processor_ = std::make_shared<MDAclProcess>(context, false);
|
processor_ = std::make_shared<MDAclProcess>(context, false);
|
||||||
ret = processor_->InitResource();
|
ret = processor_->InitResource();
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -36,7 +36,7 @@ class AscendResource : public DeviceResource {
|
||||||
AscendResource() = default;
|
AscendResource() = default;
|
||||||
~AscendResource() = default;
|
~AscendResource() = default;
|
||||||
|
|
||||||
Status InitResource() override;
|
Status InitResource(uint32_t device_id) override;
|
||||||
|
|
||||||
Status FinalizeResource() override;
|
Status FinalizeResource() override;
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
Status DeviceResource::InitResource() {
|
Status DeviceResource::InitResource(uint32_t) {
|
||||||
return Status(StatusCode::kMDUnexpectedError,
|
return Status(StatusCode::kMDUnexpectedError,
|
||||||
"Is this a valid device? If yes, please implement this InitResource() in the derived class.");
|
"Is this a valid device? If yes, please implement this InitResource() in the derived class.");
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ class DeviceResource {
|
||||||
|
|
||||||
virtual ~DeviceResource() = default;
|
virtual ~DeviceResource() = default;
|
||||||
|
|
||||||
virtual Status InitResource();
|
virtual Status InitResource(uint32_t device_id);
|
||||||
|
|
||||||
virtual Status FinalizeResource();
|
virtual Status FinalizeResource();
|
||||||
|
|
||||||
|
|
|
@ -34,18 +34,22 @@ class Execute {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
// FIXME - Temporarily overload Execute to support both TensorOperation and TensorTransform
|
// FIXME - Temporarily overload Execute to support both TensorOperation and TensorTransform
|
||||||
explicit Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
explicit Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType = MapTargetDevice::kCpu,
|
||||||
explicit Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
uint32_t device_id = 0);
|
||||||
explicit Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
explicit Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu,
|
||||||
explicit Execute(TensorTransform *op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
uint32_t device_id = 0);
|
||||||
|
explicit Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu,
|
||||||
|
uint32_t device_id = 0);
|
||||||
|
explicit Execute(TensorTransform *op, MapTargetDevice deviceType = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
|
|
||||||
explicit Execute(std::vector<std::shared_ptr<TensorOperation>> ops,
|
explicit Execute(std::vector<std::shared_ptr<TensorOperation>> ops,
|
||||||
MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
MapTargetDevice deviceType = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
explicit Execute(std::vector<std::shared_ptr<TensorTransform>> ops,
|
explicit Execute(std::vector<std::shared_ptr<TensorTransform>> ops,
|
||||||
MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
MapTargetDevice deviceType = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
explicit Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
|
explicit Execute(const std::vector<std::reference_wrapper<TensorTransform>> ops,
|
||||||
MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
MapTargetDevice deviceType = MapTargetDevice::kCpu, uint32_t device_id = 0);
|
||||||
explicit Execute(std::vector<TensorTransform *> ops, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
explicit Execute(std::vector<TensorTransform *> ops, MapTargetDevice deviceType = MapTargetDevice::kCpu,
|
||||||
|
uint32_t device_id = 0);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~Execute();
|
~Execute();
|
||||||
|
|
|
@ -78,7 +78,7 @@ Status DvppCropJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shared
|
||||||
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -71,7 +71,7 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -69,7 +69,7 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -70,7 +70,7 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr<Tensor> &input,
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -69,7 +69,7 @@ Status DvppDecodeResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std
|
||||||
imageInfo.data = static_cast<void *>(buffer);
|
imageInfo.data = static_cast<void *>(buffer);
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -79,7 +79,7 @@ Status DvppResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
||||||
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
imageinfo.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
||||||
ResourceInfo resource;
|
ResourceInfo resource;
|
||||||
resource.aclConfigPath = "";
|
resource.aclConfigPath = "";
|
||||||
resource.deviceIds.insert(mindspore::GlobalContext::GetGlobalDeviceID());
|
resource.deviceIds.insert(0);
|
||||||
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
std::shared_ptr<ResourceManager> instance = ResourceManager::GetInstance();
|
||||||
APP_ERROR ret = instance->InitResource(resource);
|
APP_ERROR ret = instance->InitResource(resource);
|
||||||
if (ret != APP_ERR_OK) {
|
if (ret != APP_ERR_OK) {
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
#include "base/base_ref_utils.h"
|
#include "base/base_ref_utils.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "include/infer_tensor.h"
|
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "include/infer_tensor.h"
|
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
#include "base/base_ref.h"
|
#include "base/base_ref.h"
|
||||||
|
|
||||||
|
|
|
@ -21,15 +21,17 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/ms_tensor.h"
|
#include "include/ms_tensor.h"
|
||||||
|
|
||||||
namespace mindspore::schema {
|
namespace mindspore {
|
||||||
struct Tensor;
|
class Allocator;
|
||||||
} // namespace mindspore::schema
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace schema {
|
||||||
|
struct Tensor;
|
||||||
|
} // namespace schema
|
||||||
|
|
||||||
|
namespace lite {
|
||||||
/// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
|
/// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
|
||||||
///
|
///
|
||||||
/// \note List public class and interface for reference.
|
/// \note List public class and interface for reference.
|
||||||
class Allocator;
|
|
||||||
|
|
||||||
/// \brief DeviceContext defined a device context.
|
/// \brief DeviceContext defined a device context.
|
||||||
struct DeviceContext;
|
struct DeviceContext;
|
||||||
|
@ -52,5 +54,6 @@ int MS_API StringsToMSTensor(const std::vector<std::string> &inputs, tensor::MST
|
||||||
/// \param[in] MSTensor.
|
/// \param[in] MSTensor.
|
||||||
/// \return string vector.
|
/// \return string vector.
|
||||||
std::vector<std::string> MS_API MSTensorToStrings(const tensor::MSTensor *tensor);
|
std::vector<std::string> MS_API MSTensorToStrings(const tensor::MSTensor *tensor);
|
||||||
} // namespace mindspore::lite
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_
|
#endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_
|
||||||
|
|
|
@ -45,6 +45,12 @@ class MS_API MSTensor {
|
||||||
/// \brief Destructor of MindSpore Lite Model.
|
/// \brief Destructor of MindSpore Lite Model.
|
||||||
virtual ~MSTensor() = default;
|
virtual ~MSTensor() = default;
|
||||||
|
|
||||||
|
/// \brief Create a MSTensor.
|
||||||
|
///
|
||||||
|
/// \return Pointer to an instance of MindSpore Lite MSTensor.
|
||||||
|
static MSTensor *CreateTensor(const std::string &name, TypeId type, const std::vector<int> &shape, const void *data,
|
||||||
|
size_t data_len);
|
||||||
|
|
||||||
/// \brief Get data type of the MindSpore Lite MSTensor.
|
/// \brief Get data type of the MindSpore Lite MSTensor.
|
||||||
///
|
///
|
||||||
/// \note TypeId is defined in mindspore/mindspore/include/api/type_id.h. Only number types in TypeId enum are
|
/// \note TypeId is defined in mindspore/mindspore/include/api/type_id.h. Only number types in TypeId enum are
|
||||||
|
@ -58,12 +64,8 @@ class MS_API MSTensor {
|
||||||
/// \return A vector of int as the shape of the MindSpore Lite MSTensor.
|
/// \return A vector of int as the shape of the MindSpore Lite MSTensor.
|
||||||
virtual std::vector<int> shape() const = 0;
|
virtual std::vector<int> shape() const = 0;
|
||||||
|
|
||||||
/// \brief Get size of the dimension of the MindSpore Lite MSTensor index by the parameter index.
|
/// \brief Set the shape of MSTensor.
|
||||||
///
|
virtual void set_shape(const std::vector<int> &name) = 0;
|
||||||
/// \param[in] index Define index of dimension returned.
|
|
||||||
///
|
|
||||||
/// \return Size of dimension of the MindSpore Lite MSTensor.
|
|
||||||
virtual int DimensionSize(size_t index) const = 0;
|
|
||||||
|
|
||||||
/// \brief Get number of element in MSTensor.
|
/// \brief Get number of element in MSTensor.
|
||||||
///
|
///
|
||||||
|
@ -75,13 +77,6 @@ class MS_API MSTensor {
|
||||||
/// \return Byte size of data in MSTensor.
|
/// \return Byte size of data in MSTensor.
|
||||||
virtual size_t Size() const = 0;
|
virtual size_t Size() const = 0;
|
||||||
|
|
||||||
/// \brief Get the pointer of data in MSTensor.
|
|
||||||
///
|
|
||||||
/// \note The data pointer can be used to both write and read data in MSTensor.
|
|
||||||
///
|
|
||||||
/// \return the pointer points to data in MSTensor.
|
|
||||||
virtual void *MutableData() = 0;
|
|
||||||
|
|
||||||
/// \brief Get the name of MSTensor.
|
/// \brief Get the name of MSTensor.
|
||||||
///
|
///
|
||||||
/// \return the name of MSTensor.
|
/// \return the name of MSTensor.
|
||||||
|
@ -90,6 +85,22 @@ class MS_API MSTensor {
|
||||||
/// \brief Set the name of MSTensor.
|
/// \brief Set the name of MSTensor.
|
||||||
virtual void set_tensor_name(const std::string name) = 0;
|
virtual void set_tensor_name(const std::string name) = 0;
|
||||||
|
|
||||||
|
/// \brief Get the pointer of data in MSTensor.
|
||||||
|
///
|
||||||
|
/// \note The data pointer can be used to both write and read data in MSTensor. The memory buffer will be
|
||||||
|
/// automatically allocated.
|
||||||
|
///
|
||||||
|
/// \return the pointer points to data in MSTensor.
|
||||||
|
virtual void *MutableData() = 0;
|
||||||
|
|
||||||
|
/// \brief Get the pointer of data in MSTensor.
|
||||||
|
///
|
||||||
|
/// \note The data pointer can be used to both write and read data in MSTensor. No memory buffer will be
|
||||||
|
/// allocated.
|
||||||
|
///
|
||||||
|
/// \return the pointer points to data in MSTensor.
|
||||||
|
virtual void *data() = 0;
|
||||||
|
|
||||||
/// \brief Set the data of MSTensor.
|
/// \brief Set the data of MSTensor.
|
||||||
virtual void set_data(void *data) = 0;
|
virtual void set_data(void *data) = 0;
|
||||||
};
|
};
|
||||||
|
|
|
@ -110,6 +110,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
||||||
${TOP_DIR}/mindspore/lite/src/cxx_api/types.cc
|
${TOP_DIR}/mindspore/lite/src/cxx_api/types.cc
|
||||||
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
|
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
|
||||||
${TOP_DIR}/mindspore/lite/src/tensor.cc
|
${TOP_DIR}/mindspore/lite/src/tensor.cc
|
||||||
|
${TOP_DIR}/mindspore/lite/src/common/string_util.cc
|
||||||
${CORE_DIR}/utils/status.cc
|
${CORE_DIR}/utils/status.cc
|
||||||
${MINDDATA_DIR}/api/datasets.cc
|
${MINDDATA_DIR}/api/datasets.cc
|
||||||
${MINDDATA_DIR}/kernels/data/data_utils.cc
|
${MINDDATA_DIR}/kernels/data/data_utils.cc
|
||||||
|
@ -304,7 +305,6 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper")
|
||||||
set(MINDSPORE_LITE_CXXAPI_SRC
|
set(MINDSPORE_LITE_CXXAPI_SRC
|
||||||
${CORE_DIR}/utils/status.cc
|
${CORE_DIR}/utils/status.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../src/cxx_api/types.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/cxx_api/types.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../src/cxx_api/tensor/tensor_impl.cc
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../src/tensor.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../src/tensor.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ set(API_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/types.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/lite_context.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/context.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model_impl.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model_impl.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/graph/graph.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/graph/graph.cc
|
||||||
|
|
|
@ -53,6 +53,7 @@ int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_b
|
||||||
}
|
}
|
||||||
std::vector<int> shape = {offset[num]};
|
std::vector<int> shape = {offset[num]};
|
||||||
tensor->set_shape(shape);
|
tensor->set_shape(shape);
|
||||||
|
tensor->set_data_type(kObjectTypeString);
|
||||||
tensor->FreeData();
|
tensor->FreeData();
|
||||||
void *data = tensor->MutableData();
|
void *data = tensor->MutableData();
|
||||||
if (data == nullptr) {
|
if (data == nullptr) {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,7 +14,6 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "include/api/cell.h"
|
#include "include/api/cell.h"
|
||||||
#include "include/api/lite_context.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -77,7 +76,7 @@ Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor>
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphCell::Load() {
|
Status GraphCell::Load(uint32_t device_id) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
MS_LOG(ERROR) << "Unsupported feature.";
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,266 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <any>
|
||||||
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/data_type.h"
|
||||||
|
#include "src/runtime/allocator.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
|
||||||
|
constexpr auto kModelOptionCpuThreadAffinity = "mindspore.option.cpu.thread_affinity";
|
||||||
|
constexpr auto kModelOptionMaliGpuEnableFP16 = "mindspore.option.mali_gpu.enable_fp16";
|
||||||
|
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||||
|
|
||||||
|
struct Context::Data {
|
||||||
|
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||||
|
int32_t thread_num;
|
||||||
|
std::shared_ptr<Allocator> allocator;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DeviceInfoContext::Data {
|
||||||
|
std::map<std::string, std::any> params;
|
||||||
|
};
|
||||||
|
|
||||||
|
Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
|
||||||
|
|
||||||
|
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
|
||||||
|
static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
|
||||||
|
static U empty_result;
|
||||||
|
if (data == nullptr) {
|
||||||
|
return empty_result;
|
||||||
|
}
|
||||||
|
auto iter = data->params.find(key);
|
||||||
|
if (iter == data->params.end()) {
|
||||||
|
return empty_result;
|
||||||
|
}
|
||||||
|
const std::any &value = iter->second;
|
||||||
|
|
||||||
|
return std::any_cast<const U &>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Context::SetThreadNum(int32_t thread_num) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->thread_num = thread_num;
|
||||||
|
}
|
||||||
|
int32_t Context::GetThreadNum() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return data_->thread_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Context::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->allocator = allocator;
|
||||||
|
}
|
||||||
|
std::shared_ptr<Allocator> Context::GetAllocator() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return data_->allocator;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
|
||||||
|
static std::vector<std::shared_ptr<DeviceInfoContext>> empty;
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return data_->device_info_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceInfoContext::DeviceInfoContext() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
|
||||||
|
|
||||||
|
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->params[kModelOptionCpuEnableFP16] = is_fp16;
|
||||||
|
}
|
||||||
|
bool CPUDeviceInfo::GetEnableFP16() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CPUDeviceInfo::SetThreadAffinity(int affinity) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->params[kModelOptionCpuThreadAffinity] = affinity;
|
||||||
|
}
|
||||||
|
int CPUDeviceInfo::GetThreadAffinity() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return GetValue<bool>(data_, kModelOptionCpuThreadAffinity);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaliGPUDeviceInfo::SetEnableFP16(bool is_fp16) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->params[kModelOptionMaliGpuEnableFP16] = is_fp16;
|
||||||
|
}
|
||||||
|
bool MaliGPUDeviceInfo::GetEnableFP16() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return GetValue<bool>(data_, kModelOptionMaliGpuEnableFP16);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KirinNPUDeviceInfo::SetFrequency(int frequency) {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data_->params[kModelOptionKirinNpuFrequency] = frequency;
|
||||||
|
}
|
||||||
|
int KirinNPUDeviceInfo::GetFrequency() const {
|
||||||
|
if (data_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NvidiaGPUDeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
uint32_t NvidiaGPUDeviceInfo::GetDeviceID() const {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NvidiaGPUDeviceInfo::SetGpuTrtInferMode(bool gpu_trt_infer_mode) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
bool NvidiaGPUDeviceInfo::GetGpuTrtInferMode() const {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
uint32_t Ascend910DeviceInfo::GetDeviceID() const {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
|
||||||
|
std::vector<char> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
}
|
||||||
|
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
||||||
|
std::map<int, std::vector<int>> empty;
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||||
|
enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
||||||
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
|
return DataType::kTypeUnknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
|
Graph::Graph() : graph_data_(nullptr) {}
|
||||||
|
|
||||||
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
|
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
|
||||||
|
|
||||||
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
||||||
|
@ -30,5 +32,7 @@ Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
|
||||||
|
|
||||||
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
|
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
|
||||||
|
|
||||||
|
bool Graph::operator!=(std::nullptr_t) const { return graph_data_ != nullptr; }
|
||||||
|
|
||||||
ModelType Graph::ModelType() const { return kMindIR; }
|
ModelType Graph::ModelType() const { return kMindIR; }
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,303 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "include/api/lite_context.h"
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include <any>
|
|
||||||
#include "include/api/types.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
|
|
||||||
constexpr char kVendorName[] = "vendor_name";
|
|
||||||
constexpr char kThreadNum[] = "thread_name";
|
|
||||||
constexpr char kAllocator[] = "allocator";
|
|
||||||
constexpr char kCPU[] = "cpu";
|
|
||||||
constexpr char kCPUEanbleFp16[] = "cpu_enable_fp16";
|
|
||||||
constexpr char kCPUBindMode[] = "cpu_bind_mode";
|
|
||||||
constexpr char kGPU[] = "gpu";
|
|
||||||
constexpr char kGPUEanbleFp16[] = "gpu_enable_fp16";
|
|
||||||
constexpr char kNPU[] = "npu";
|
|
||||||
constexpr char kNPUFrequency[] = "npu_frequency";
|
|
||||||
|
|
||||||
void Context::Clear(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
context->context_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::SetAsDefault(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
context->context_.clear();
|
|
||||||
context->context_.emplace(kCPU, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::SetVendorName(const std::shared_ptr<Context> &context, const std::string &name) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kVendorName);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = name;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kVendorName, name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Context::GetVendorName(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return std::string();
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kVendorName);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<const std::string>(iter->second);
|
|
||||||
}
|
|
||||||
return std::string();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::SetThreadNum(const std::shared_ptr<Context> &context, int num) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kThreadNum);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = num;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kThreadNum, num);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int Context::GetThreadNum(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kThreadNum);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<int>(iter->second);
|
|
||||||
}
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::SetAllocator(const std::shared_ptr<Context> &context, std::shared_ptr<lite::Allocator> alloc) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kAllocator);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = alloc;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kAllocator, alloc);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<lite::Allocator> Context::GetAllocator(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kAllocator);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<std::shared_ptr<lite::Allocator>>(iter->second);
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::ConfigCPU(const std::shared_ptr<Context> &context, bool conf) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kCPU);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = conf;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kCPU, conf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Context::IfCPUEnabled(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kCPU);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<bool>(iter->second);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::ConfigCPUFp16(const std::shared_ptr<Context> &context, bool conf) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kCPUEanbleFp16);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = conf;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kCPUEanbleFp16, conf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Context::IfCPUFp16Enabled(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kCPUEanbleFp16);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<bool>(iter->second);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::SetCPUBindMode(const std::shared_ptr<Context> &context, lite::CpuBindMode mode) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kCPUBindMode);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = mode;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kCPUBindMode, mode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lite::CpuBindMode Context::GetCPUBindMode(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return lite::NO_BIND;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kCPUBindMode);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<lite::CpuBindMode>(iter->second);
|
|
||||||
}
|
|
||||||
return lite::MID_CPU;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::ConfigGPU(const std::shared_ptr<Context> &context, bool conf) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kGPU);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = conf;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kGPU, conf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Context::IfGPUEnabled(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kGPU);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<bool>(iter->second);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::ConfigGPUFp16(const std::shared_ptr<Context> &context, bool conf) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kGPUEanbleFp16);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = conf;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kGPUEanbleFp16, conf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Context::IfGPUFp16Enabled(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kGPUEanbleFp16);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<bool>(iter->second);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::ConfigNPU(const std::shared_ptr<Context> &context, bool conf) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kNPU);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = conf;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kNPU, conf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Context::IfNPUEnabled(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kNPU);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<bool>(iter->second);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Context::SetNPUFrequency(const std::shared_ptr<Context> &context, int freq) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kNPUFrequency);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
iter->second = freq;
|
|
||||||
} else {
|
|
||||||
context->context_.emplace(kNPUFrequency, freq);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int Context::GetNPUFrequency(const std::shared_ptr<Context> &context) {
|
|
||||||
if (context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Context is nullptr.";
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
auto iter = context->context_.find(kNPUFrequency);
|
|
||||||
if (iter != context->context_.end()) {
|
|
||||||
return std::any_cast<int>(iter->second);
|
|
||||||
}
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,17 +14,30 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "include/api/model.h"
|
#include "include/api/model.h"
|
||||||
#include "include/api/lite_context.h"
|
#include "include/api/types.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
#include "src/cxx_api/model/model_impl.h"
|
#include "src/cxx_api/model/model_impl.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
Status Model::Build() {
|
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context) {
|
||||||
|
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
|
if (graph.GetGraph() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid graph.";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
if (model_context == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid context.";
|
||||||
|
return kLiteNullptr;
|
||||||
|
}
|
||||||
|
impl_->SetContext(model_context);
|
||||||
|
impl_->SetGraph(graph.GetGraph());
|
||||||
return impl_->Build();
|
return impl_->Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,30 +57,11 @@ Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor>
|
||||||
return impl_->Predict(inputs, outputs);
|
return impl_->Predict(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
Model::Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context) {
|
Model::Model() : impl_(nullptr) {}
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
|
||||||
if (impl_ == nullptr || graph.GetGraph() == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Invalid graph.";
|
|
||||||
} else if (model_context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Invalid context.";
|
|
||||||
} else {
|
|
||||||
auto new_graph_cell = std::shared_ptr<GraphCell>(new (std::nothrow) GraphCell(graph));
|
|
||||||
if (new_graph_cell != nullptr) {
|
|
||||||
impl_->SetContext(model_context);
|
|
||||||
impl_->SetGraphCell(new_graph_cell);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "New graphcell failed.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
|
||||||
}
|
|
||||||
|
|
||||||
Model::~Model() {}
|
Model::~Model() {}
|
||||||
|
|
||||||
bool Model::CheckModelSupport(const std::vector<char> &, ModelType) {
|
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
MS_LOG(ERROR) << "Unsupported feature.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -90,4 +84,37 @@ std::vector<MSTensor> Model::GetOutputs() {
|
||||||
return impl_->GetOutputs();
|
return impl_->GetOutputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSTensor Model::GetInputByTensorName(const std::vector<char> &name) {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
return impl_->GetInputByTensorName(CharToString(name));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
std::vector<std::vector<char>> empty;
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return VectorStringToChar(impl_->GetOutputTensorNames());
|
||||||
|
}
|
||||||
|
|
||||||
|
MSTensor Model::GetOutputByTensorName(const std::vector<char> &name) {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
return impl_->GetOutputByTensorName(CharToString(name));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
std::vector<MSTensor> empty;
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return impl_->GetOutputsByNodeName(CharToString(node_name));
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -19,14 +19,16 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/lite_context.h"
|
#include "include/api/context.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
#include "include/lite_session.h"
|
#include "include/lite_session.h"
|
||||||
#include "include/context.h"
|
#include "include/context.h"
|
||||||
#include "src/lite_model.h"
|
#include "src/lite_model.h"
|
||||||
#include "src/runtime/allocator.h"
|
#include "src/runtime/allocator.h"
|
||||||
#include "src/cxx_api/utils.h"
|
#include "src/common/string_util.h"
|
||||||
#include "src/cxx_api/graph/graph_data.h"
|
#include "src/cxx_api/graph/graph_data.h"
|
||||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||||
|
#include "src/cxx_api/tensor_utils.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -39,13 +41,9 @@ Status ModelImpl::Build() {
|
||||||
MS_LOG(DEBUG) << "Model has been already built.";
|
MS_LOG(DEBUG) << "Model has been already built.";
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
if (graph_cell_ == nullptr || graph_cell_->GetGraph() == nullptr || graph_cell_->GetGraph()->graph_data_ == nullptr) {
|
auto model = graph_->graph_data_->lite_model();
|
||||||
MS_LOG(ERROR) << "Graph cell is invalid.";
|
if (graph_ == nullptr || graph_->graph_data_ == nullptr || model == nullptr) {
|
||||||
return kLiteNullptr;
|
MS_LOG(ERROR) << "Invalid graph.";
|
||||||
}
|
|
||||||
auto model = graph_cell_->GetGraph()->graph_data_->lite_model();
|
|
||||||
if (model == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Lite model is nullptr.";
|
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
if (model->buf == nullptr) {
|
if (model->buf == nullptr) {
|
||||||
|
@ -57,36 +55,58 @@ Status ModelImpl::Build() {
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
lite::Context model_context;
|
lite::Context model_context;
|
||||||
model_context.allocator = Context::GetAllocator(context_);
|
auto device_list = context_->MutableDeviceInfo();
|
||||||
|
if (device_list.size() == 0) {
|
||||||
|
MS_LOG(ERROR) << "Invalid device list.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
if (device_list.size() > 2) {
|
||||||
|
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
|
model_context.allocator = context_->GetAllocator();
|
||||||
if (model_context.allocator == nullptr) {
|
if (model_context.allocator == nullptr) {
|
||||||
model_context.allocator = lite::Allocator::Create();
|
model_context.allocator = Allocator::Create();
|
||||||
if (model_context.allocator == nullptr) {
|
if (model_context.allocator == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create Allocator failed.";
|
MS_LOG(ERROR) << "Create Allocator failed.";
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Set new allocator.";
|
MS_LOG(DEBUG) << "Set new allocator.";
|
||||||
Context::SetAllocator(context_, model_context.allocator);
|
context_->SetAllocator(model_context.allocator);
|
||||||
}
|
}
|
||||||
model_context.vendor_name_ = Context::GetVendorName(context_);
|
model_context.thread_num_ = context_->GetThreadNum();
|
||||||
model_context.thread_num_ = Context::GetThreadNum(context_);
|
|
||||||
model_context.device_list_.clear();
|
model_context.device_list_.clear();
|
||||||
if (Context::IfCPUEnabled(context_) && Context::IfGPUEnabled(context_) && Context::IfNPUEnabled(context_)) {
|
if (device_list[0]->GetDeviceType() != kCPU) {
|
||||||
MS_LOG(ERROR) << "CPU/GPU/NPU cannot be enabled at the same time.";
|
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
|
||||||
return kLiteInputParamInvalid;
|
return kLiteInputParamInvalid;
|
||||||
}
|
}
|
||||||
if (!Context::IfCPUEnabled(context_)) {
|
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
|
||||||
MS_LOG(INFO) << "CPU is forced to be enabled.";
|
lite::CpuBindMode mode;
|
||||||
|
if (cpu_context->GetThreadAffinity() == 0) {
|
||||||
|
mode = lite::NO_BIND;
|
||||||
|
} else if (cpu_context->GetThreadAffinity() == 1) {
|
||||||
|
mode = lite::HIGHER_CPU;
|
||||||
|
} else if (cpu_context->GetThreadAffinity() == 2) {
|
||||||
|
mode = lite::MID_CPU;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Invalid thread affinity.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
}
|
}
|
||||||
lite::DeviceInfo cpu_info = {
|
lite::DeviceInfo cpu_info = {.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode}};
|
||||||
.cpu_device_info_ = {Context::IfCPUFp16Enabled(context_), Context::GetCPUBindMode(context_)}};
|
|
||||||
model_context.device_list_.push_back({lite::DT_CPU, cpu_info});
|
model_context.device_list_.push_back({lite::DT_CPU, cpu_info});
|
||||||
if (Context::IfGPUEnabled(context_)) {
|
if (device_list.size() == 2) {
|
||||||
lite::DeviceInfo gpu_info = {.gpu_device_info_ = {Context::IfGPUFp16Enabled(context_)}};
|
if (device_list[0]->GetDeviceType() == kMaliGPU) {
|
||||||
model_context.device_list_.push_back({lite::DT_GPU, gpu_info});
|
auto gpu_context = device_list[0]->Cast<MaliGPUDeviceInfo>();
|
||||||
}
|
lite::DeviceInfo gpu_info = {.gpu_device_info_ = {gpu_context->GetEnableFP16()}};
|
||||||
if (Context::IfNPUEnabled(context_)) {
|
model_context.device_list_.push_back({lite::DT_GPU, gpu_info});
|
||||||
lite::DeviceInfo npu_info = {.npu_device_info_ = {Context::GetNPUFrequency(context_)}};
|
} else if (device_list[0]->GetDeviceType() == kKirinNPU) {
|
||||||
model_context.device_list_.push_back({lite::DT_NPU, npu_info});
|
auto npu_context = device_list[0]->Cast<KirinNPUDeviceInfo>();
|
||||||
|
lite::DeviceInfo npu_info = {.npu_device_info_ = {npu_context->GetFrequency()}};
|
||||||
|
model_context.device_list_.push_back({lite::DT_NPU, npu_info});
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Invalid device.";
|
||||||
|
return kLiteInputParamInvalid;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
|
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
|
@ -98,12 +118,19 @@ Status ModelImpl::Build() {
|
||||||
MS_LOG(ERROR) << "Build model failed.";
|
MS_LOG(ERROR) << "Build model failed.";
|
||||||
return static_cast<StatusCode>(ret);
|
return static_cast<StatusCode>(ret);
|
||||||
}
|
}
|
||||||
|
session->BindThread(true);
|
||||||
session_.swap(session);
|
session_.swap(session);
|
||||||
model->Free();
|
model->Free();
|
||||||
MS_LOG(DEBUG) << "Build model success.";
|
MS_LOG(DEBUG) << "Build model success.";
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors) {
|
||||||
|
for (size_t j = 0; j < old_data.size(); j++) {
|
||||||
|
tensors.at(j)->set_data(old_data.at(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
|
||||||
if (session_ == nullptr) {
|
if (session_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Run graph failed.";
|
MS_LOG(ERROR) << "Run graph failed.";
|
||||||
|
@ -122,35 +149,44 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
|
||||||
for (size_t i = 0; i < inputs.size(); i++) {
|
for (size_t i = 0; i < inputs.size(); i++) {
|
||||||
auto input = input_tensors.at(i);
|
auto input = input_tensors.at(i);
|
||||||
auto user_input = inputs.at(i);
|
auto user_input = inputs.at(i);
|
||||||
|
if (user_input.DataType() != static_cast<enum DataType>(input->data_type())) {
|
||||||
|
ResetTensorData(old_data, input_tensors);
|
||||||
|
MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has a different data type from input" << input->tensor_name()
|
||||||
|
<< ".";
|
||||||
|
return kLiteInputTensorError;
|
||||||
|
}
|
||||||
if (user_input.Name() != input->tensor_name()) {
|
if (user_input.Name() != input->tensor_name()) {
|
||||||
MS_LOG(WARNING) << "Tensor " << user_input.Name() << " has a different name from input" << input->tensor_name()
|
MS_LOG(WARNING) << "Tensor " << user_input.Name() << " has a different name from input" << input->tensor_name()
|
||||||
<< ".";
|
<< ".";
|
||||||
}
|
}
|
||||||
old_data.push_back(input->MutableData());
|
old_data.push_back(input->data());
|
||||||
if (user_input.MutableData() != input->MutableData()) {
|
if (input->data_type() == kObjectTypeString) {
|
||||||
if (input->Size() != user_input.DataSize()) {
|
std::vector<int32_t> shape = TruncateShape(user_input.Shape(), input->data_type(), user_input.DataSize(), false);
|
||||||
for (size_t j = 0; j < old_data.size(); j++) {
|
if (shape.empty() && !(user_input.Shape().empty())) {
|
||||||
input_tensors.at(j)->set_data(old_data.at(j));
|
ResetTensorData(old_data, input_tensors);
|
||||||
}
|
MS_LOG(ERROR) << "Input dims of tensor " << user_input.Name() << " is invalid.";
|
||||||
MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has wrong data size.";
|
return kLiteParamInvalid;
|
||||||
return kLiteInputTensorError;
|
|
||||||
}
|
}
|
||||||
if (user_input.impl_->need_copy()) {
|
input->set_shape(shape);
|
||||||
::memcpy(input->MutableData(), user_input.MutableData(), input->Size());
|
input->set_data(user_input.MutableData());
|
||||||
} else {
|
} else {
|
||||||
|
if (user_input.MutableData() != input->data()) {
|
||||||
|
if (input->Size() != user_input.DataSize()) {
|
||||||
|
ResetTensorData(old_data, input_tensors);
|
||||||
|
MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has wrong data size.";
|
||||||
|
return kLiteInputTensorError;
|
||||||
|
}
|
||||||
input->set_data(user_input.MutableData());
|
input->set_data(user_input.MutableData());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto ret = session_->RunGraph();
|
auto ret = session_->RunGraph();
|
||||||
|
ResetTensorData(old_data, input_tensors);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run graph failed.";
|
MS_LOG(ERROR) << "Run graph failed.";
|
||||||
return static_cast<StatusCode>(ret);
|
return static_cast<StatusCode>(ret);
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Run graph success.";
|
MS_LOG(DEBUG) << "Run graph success.";
|
||||||
for (size_t i = 0; i < old_data.size(); i++) {
|
|
||||||
input_tensors.at(i)->set_data(old_data.at(i));
|
|
||||||
}
|
|
||||||
auto res = GetOutputs();
|
auto res = GetOutputs();
|
||||||
if (res.empty()) {
|
if (res.empty()) {
|
||||||
MS_LOG(DEBUG) << "Empty outputs.";
|
MS_LOG(DEBUG) << "Empty outputs.";
|
||||||
|
@ -176,7 +212,7 @@ std::vector<MSTensor> ModelImpl::GetInputs() {
|
||||||
res.resize(inputs.size());
|
res.resize(inputs.size());
|
||||||
for (size_t i = 0; i < inputs.size(); i++) {
|
for (size_t i = 0; i < inputs.size(); i++) {
|
||||||
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(inputs[i]));
|
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(inputs[i]));
|
||||||
if (impl == nullptr) {
|
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create tensor failed.";
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
return empty;
|
return empty;
|
||||||
}
|
}
|
||||||
|
@ -214,7 +250,83 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
|
||||||
res.resize(names.size());
|
res.resize(names.size());
|
||||||
for (size_t i = 0; i < names.size(); i++) {
|
for (size_t i = 0; i < names.size(); i++) {
|
||||||
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[names[i]]));
|
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[names[i]]));
|
||||||
if (impl == nullptr) {
|
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
auto tensor = MSTensor(impl);
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
res[i] = tensor;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Session is null.";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
auto res = session_->GetInputsByTensorName(name);
|
||||||
|
if (res == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model does not contains tensor " << name << " .";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(res));
|
||||||
|
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
return MSTensor(impl);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> ModelImpl::GetOutputTensorNames() {
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Session is null.";
|
||||||
|
std::vector<std::string> empty;
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return session_->GetOutputTensorNames();
|
||||||
|
}
|
||||||
|
|
||||||
|
MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) {
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Session is null.";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
auto res = session_->GetOutputByTensorName(name);
|
||||||
|
if (res == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model does not contains tensor " << name << " .";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(res));
|
||||||
|
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
|
return MSTensor(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
return MSTensor(impl);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<MSTensor> ModelImpl::GetOutputsByNodeName(const std::string &name) {
|
||||||
|
std::vector<MSTensor> empty;
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Session is null.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
std::vector<MSTensor> res;
|
||||||
|
auto outputs = session_->GetOutputsByNodeName(name);
|
||||||
|
if (outputs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "The outputs of model is null.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
res.resize(outputs.size());
|
||||||
|
for (size_t i = 0; i < outputs.size(); i++) {
|
||||||
|
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[i]));
|
||||||
|
if (impl == nullptr || impl->lite_tensor() == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create tensor failed.";
|
MS_LOG(ERROR) << "Create tensor failed.";
|
||||||
return empty;
|
return empty;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -23,14 +23,14 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "include/api/model.h"
|
#include "include/api/model.h"
|
||||||
#include "include/api/lite_context.h"
|
#include "include/api/context.h"
|
||||||
#include "include/api/cell.h"
|
#include "include/api/cell.h"
|
||||||
#include "include/lite_session.h"
|
#include "include/lite_session.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelImpl {
|
class ModelImpl {
|
||||||
public:
|
public:
|
||||||
ModelImpl() : graph_cell_(nullptr), session_(nullptr), context_(nullptr) {}
|
ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
|
||||||
~ModelImpl() = default;
|
~ModelImpl() = default;
|
||||||
|
|
||||||
Status Build();
|
Status Build();
|
||||||
|
@ -40,15 +40,19 @@ class ModelImpl {
|
||||||
|
|
||||||
std::vector<MSTensor> GetInputs();
|
std::vector<MSTensor> GetInputs();
|
||||||
std::vector<MSTensor> GetOutputs();
|
std::vector<MSTensor> GetOutputs();
|
||||||
|
MSTensor GetInputByTensorName(const std::string &name);
|
||||||
|
std::vector<std::string> GetOutputTensorNames();
|
||||||
|
MSTensor GetOutputByTensorName(const std::string &name);
|
||||||
|
std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
|
||||||
|
|
||||||
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Model;
|
friend class Model;
|
||||||
std::shared_ptr<GraphCell> graph_cell_;
|
std::shared_ptr<Graph> graph_;
|
||||||
std::shared_ptr<session::LiteSession> session_;
|
std::shared_ptr<session::LiteSession> session_;
|
||||||
std::shared_ptr<Context> context_;
|
std::shared_ptr<Context> context_;
|
||||||
void SetGraphCell(const std::shared_ptr<GraphCell> &graph_cell) { graph_cell_ = graph_cell; }
|
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||||
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
|
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -19,7 +19,7 @@
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include "include/api/graph.h"
|
#include "include/api/graph.h"
|
||||||
#include "include/api/lite_context.h"
|
#include "include/api/context.h"
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/model.h"
|
#include "include/model.h"
|
||||||
#include "include/ms_tensor.h"
|
#include "include/ms_tensor.h"
|
||||||
|
@ -28,28 +28,28 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelType model_type) {
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph) {
|
||||||
if (model_type != kMindIR) {
|
if (model_type != kMindIR) {
|
||||||
MS_LOG(ERROR) << "Unsupported IR.";
|
MS_LOG(ERROR) << "Unsupported IR.";
|
||||||
return Graph(nullptr);
|
return kLiteInputParamInvalid;
|
||||||
}
|
}
|
||||||
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
|
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "New model failed.";
|
MS_LOG(ERROR) << "New model failed.";
|
||||||
return Graph(nullptr);
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
|
auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
|
||||||
if (graph_data == nullptr) {
|
if (graph_data == nullptr) {
|
||||||
MS_LOG(ERROR) << "New graph data failed.";
|
MS_LOG(ERROR) << "New graph data failed.";
|
||||||
return Graph(nullptr);
|
return kLiteMemoryFailed;
|
||||||
}
|
}
|
||||||
Graph graph = Graph(graph_data);
|
*graph = Graph(graph_data);
|
||||||
return graph;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
|
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
return Graph(nullptr);
|
return kLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,27 +13,69 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
#include <cstddef>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||||
|
#include "src/cxx_api/tensor_utils.h"
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
#include "src/cxx_api/utils.h"
|
#include "include/ms_tensor.h"
|
||||||
|
#include "src/common/string_util.h"
|
||||||
|
#include "src/tensor.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "ir/dtype/type_id.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
MSTensor::Impl::Impl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
using mindspore::lite::RET_OK;
|
||||||
size_t data_len) {
|
|
||||||
|
MSTensor::Impl *MSTensor::Impl::CreateTensorImpl(const std::string &name, enum DataType type,
|
||||||
|
const std::vector<int64_t> &shape, const void *data, size_t data_len) {
|
||||||
std::vector<int32_t> truncated_shape = TruncateShape(shape, static_cast<enum TypeId>(type), data_len, true);
|
std::vector<int32_t> truncated_shape = TruncateShape(shape, static_cast<enum TypeId>(type), data_len, true);
|
||||||
if (truncated_shape.empty() && !(shape.empty())) {
|
if (truncated_shape.empty() && !(shape.empty())) {
|
||||||
lite_tensor_ = nullptr;
|
MS_LOG(ERROR) << "Invalid shape for creating tensor.";
|
||||||
} else {
|
return nullptr;
|
||||||
lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast<enum TypeId>(type), truncated_shape, data);
|
|
||||||
}
|
}
|
||||||
|
auto lite_tensor = lite::Tensor::CreateTensor(name, static_cast<enum TypeId>(type), truncated_shape, data, data_len);
|
||||||
|
if (lite_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to allocate lite tensor.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto impl = new (std::nothrow) Impl();
|
||||||
|
if (impl == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to allocate tensor impl.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
impl->set_lite_tensor(lite_tensor);
|
||||||
|
return impl;
|
||||||
|
}
|
||||||
|
|
||||||
|
MSTensor::Impl *MSTensor::Impl::StringsToTensorImpl(const std::string &name, const std::vector<std::string> &str) {
|
||||||
|
auto lite_tensor = new (std::nothrow) lite::Tensor();
|
||||||
|
if (lite_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to allocate lite tensor.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
lite_tensor->set_tensor_name(name);
|
||||||
|
auto ret = lite::StringsToMSTensor(str, lite_tensor);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Convert strings to tensor failed.";
|
||||||
|
delete lite_tensor;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto impl = new (std::nothrow) Impl();
|
||||||
|
if (impl == nullptr) {
|
||||||
|
delete lite_tensor;
|
||||||
|
MS_LOG(ERROR) << "Failed to allocate tensor impl.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
impl->set_lite_tensor(lite_tensor);
|
||||||
|
impl->set_own_data(true);
|
||||||
|
return impl;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,25 +22,51 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
|
#include "include/lite_utils.h"
|
||||||
#include "include/ms_tensor.h"
|
#include "include/ms_tensor.h"
|
||||||
#include "src/tensor.h"
|
#include "src/tensor.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
using mindspore::lite::RET_OK;
|
||||||
|
|
||||||
class MSTensor::Impl {
|
class MSTensor::Impl {
|
||||||
public:
|
public:
|
||||||
Impl() {}
|
Impl() {}
|
||||||
virtual ~Impl() = default;
|
|
||||||
explicit Impl(tensor::MSTensor *tensor) : lite_tensor_(tensor) {
|
virtual ~Impl() {
|
||||||
|
if (lite_tensor_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!from_session_) {
|
||||||
|
if (!own_data_) {
|
||||||
|
lite_tensor_->set_data(nullptr);
|
||||||
|
}
|
||||||
|
delete lite_tensor_;
|
||||||
|
lite_tensor_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit Impl(tensor::MSTensor *tensor) : lite_tensor_(tensor), from_session_(true) {
|
||||||
if (tensor != nullptr) {
|
if (tensor != nullptr) {
|
||||||
tensor_name_ = tensor->tensor_name();
|
tensor_name_ = tensor->tensor_name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(std::nullptr_t) const { return lite_tensor_ == nullptr; }
|
static Impl *CreateTensorImpl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
|
const void *data, size_t data_len);
|
||||||
|
|
||||||
Impl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
static Impl *StringsToTensorImpl(const std::string &name, const std::vector<std::string> &str);
|
||||||
size_t data_len);
|
|
||||||
|
static std::vector<std::string> TensorImplToStrings(const std::shared_ptr<Impl> &impl) {
|
||||||
|
std::vector<std::string> empty;
|
||||||
|
auto lite_tensor = impl->lite_tensor();
|
||||||
|
if (lite_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid tensor impl.";
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return lite::MSTensorToStrings(lite_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
virtual const std::string &Name() const {
|
virtual const std::string &Name() const {
|
||||||
static std::string empty = "";
|
static std::string empty = "";
|
||||||
|
@ -110,11 +136,6 @@ class MSTensor::Impl {
|
||||||
|
|
||||||
virtual bool IsDevice() const { return false; }
|
virtual bool IsDevice() const { return false; }
|
||||||
|
|
||||||
virtual std::shared_ptr<Impl> Clone() const {
|
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
tensor::MSTensor *lite_tensor() { return lite_tensor_; }
|
tensor::MSTensor *lite_tensor() { return lite_tensor_; }
|
||||||
|
|
||||||
Status set_lite_tensor(tensor::MSTensor *tensor) {
|
Status set_lite_tensor(tensor::MSTensor *tensor) {
|
||||||
|
@ -126,15 +147,14 @@ class MSTensor::Impl {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_need_copy(bool need_copy) { need_copy_ = need_copy; }
|
void set_own_data(bool own_data) { own_data_ = own_data; }
|
||||||
|
|
||||||
bool need_copy() { return need_copy_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
tensor::MSTensor *lite_tensor_;
|
tensor::MSTensor *lite_tensor_ = nullptr;
|
||||||
std::string tensor_name_;
|
std::string tensor_name_ = "";
|
||||||
std::vector<int64_t> shape_;
|
std::vector<int64_t> shape_ = {};
|
||||||
bool need_copy_ = true;
|
bool own_data_ = false;
|
||||||
|
bool from_session_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
#include <limits.h>
|
#include <limits.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/tensor.h"
|
#include "ir/dtype/type_id.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
|
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,7 +18,9 @@
|
||||||
#include <limits.h>
|
#include <limits.h>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
#include "src/cxx_api/tensor/tensor_impl.h"
|
#include "src/cxx_api/tensor/tensor_impl.h"
|
||||||
|
#include "src/common/string_util.h"
|
||||||
#include "src/tensor.h"
|
#include "src/tensor.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
|
@ -62,40 +64,106 @@ MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
|
||||||
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) {}
|
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) {}
|
||||||
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len)
|
const void *data, size_t data_len)
|
||||||
: impl_(std::make_shared<Impl>(CharToString(name), type, shape, data, data_len)) {}
|
: impl_(std::shared_ptr<Impl>(Impl::CreateTensorImpl(CharToString(name), type, shape, data, data_len))) {}
|
||||||
MSTensor::~MSTensor() = default;
|
MSTensor::~MSTensor() = default;
|
||||||
|
|
||||||
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
|
bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
|
||||||
|
|
||||||
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
bool MSTensor::operator!=(std::nullptr_t) const { return impl_ != nullptr; }
|
||||||
const void *data, size_t data_len) noexcept {
|
|
||||||
auto impl = std::make_shared<Impl>(CharToString(name), type, shape, data, data_len);
|
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||||
|
const void *data, size_t data_len) noexcept {
|
||||||
|
auto new_data = malloc(data_len);
|
||||||
|
if (new_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate data failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
::memcpy(new_data, data, data_len);
|
||||||
|
auto impl = std::shared_ptr<Impl>(Impl::CreateTensorImpl(CharToString(name), type, shape, new_data, data_len));
|
||||||
if (impl == nullptr) {
|
if (impl == nullptr) {
|
||||||
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
return MSTensor(nullptr);
|
free(new_data);
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
return MSTensor(impl);
|
auto ms_tensor = new (std::nothrow) MSTensor(impl);
|
||||||
|
if (ms_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
free(new_data);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
impl->set_own_data(true);
|
||||||
|
return ms_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
MSTensor *MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type,
|
||||||
const void *data, size_t data_len) noexcept {
|
const std::vector<int64_t> &shape, const void *data, size_t data_len) noexcept {
|
||||||
auto tensor = CreateTensor(name, type, shape, data, data_len);
|
auto impl = std::shared_ptr<Impl>(Impl::CreateTensorImpl(CharToString(name), type, shape, data, data_len));
|
||||||
if (tensor == nullptr) {
|
if (impl == nullptr) {
|
||||||
return MSTensor(nullptr);
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
tensor.impl_->set_need_copy(false);
|
auto ms_tensor = new (std::nothrow) MSTensor(impl);
|
||||||
return tensor;
|
if (ms_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return ms_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSTensor MSTensor::Clone() const {
|
MSTensor *MSTensor::CharStringsToTensor(const std::vector<char> &name, const std::vector<std::vector<char>> &inputs) {
|
||||||
MSTensor ret;
|
auto impl = std::shared_ptr<Impl>(Impl::StringsToTensorImpl(CharToString(name), VectorCharToString(inputs)));
|
||||||
|
if (impl == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto ms_tensor = new (std::nothrow) MSTensor(impl);
|
||||||
|
if (ms_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return ms_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<char>> MSTensor::TensorToStringChars(const MSTensor &tensor) {
|
||||||
|
if (tensor.impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Invalid tensor.";
|
||||||
|
std::vector<std::vector<char>> empty;
|
||||||
|
return empty;
|
||||||
|
}
|
||||||
|
return VectorStringToChar(Impl::TensorImplToStrings(tensor.impl_));
|
||||||
|
}
|
||||||
|
|
||||||
|
MSTensor *MSTensor::Clone() const {
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid tensor inpmlement.";
|
MS_LOG(ERROR) << "Invalid tensor.";
|
||||||
ret.impl_ = nullptr;
|
return nullptr;
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
ret.impl_ = impl_->Clone();
|
auto data_len = this->DataSize();
|
||||||
return ret;
|
if (data_len <= 0) {
|
||||||
|
MS_LOG(ERROR) << "Illegal data size of tensor.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto new_data = malloc(data_len);
|
||||||
|
if (new_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate data failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto impl =
|
||||||
|
std::shared_ptr<Impl>(Impl::CreateTensorImpl(this->Name(), this->DataType(), this->Shape(), new_data, data_len));
|
||||||
|
if (impl == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
free(new_data);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto ms_tensor = new (std::nothrow) MSTensor(impl);
|
||||||
|
if (ms_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Allocate tensor impl failed.";
|
||||||
|
free(new_data);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
::memcpy(new_data, impl_->MutableData(), data_len);
|
||||||
|
impl->set_own_data(true);
|
||||||
|
return ms_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> MSTensor::CharName() const {
|
std::vector<char> MSTensor::CharName() const {
|
||||||
|
@ -160,10 +228,14 @@ bool MSTensor::IsDevice() const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer::Buffer() : impl_(std::make_shared<Impl>()) { MS_LOG(ERROR) << "Unsupported feature."; }
|
void MSTensor::DestroyTensorPtr(MSTensor *tensor) noexcept {
|
||||||
Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {
|
if (tensor != nullptr) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
delete tensor;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Buffer::Buffer() : impl_(nullptr) { MS_LOG(ERROR) << "Unsupported feature."; }
|
||||||
|
Buffer::Buffer(const void *data, size_t data_len) : impl_(nullptr) { MS_LOG(ERROR) << "Unsupported feature."; }
|
||||||
Buffer::~Buffer() = default;
|
Buffer::~Buffer() = default;
|
||||||
|
|
||||||
Buffer Buffer::Clone() const {
|
Buffer Buffer::Clone() const {
|
||||||
|
|
|
@ -41,8 +41,8 @@ int Executor::CheckInputs(const std::vector<Tensor *> &in_tensors) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int Executor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
int Executor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator, const KernelCallBack &before,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
|
||||||
const KernelCallBack &after) {
|
const KernelCallBack &before, const KernelCallBack &after) {
|
||||||
MS_ASSERT(nullptr != allocator);
|
MS_ASSERT(nullptr != allocator);
|
||||||
auto ret = this->CheckInputs(in_tensors);
|
auto ret = this->CheckInputs(in_tensors);
|
||||||
if (RET_OK != ret) {
|
if (RET_OK != ret) {
|
||||||
|
|
|
@ -31,7 +31,7 @@ class Executor {
|
||||||
virtual int Prepare(const std::vector<kernel::LiteKernel *> &kernels) { return RET_OK; }
|
virtual int Prepare(const std::vector<kernel::LiteKernel *> &kernels) { return RET_OK; }
|
||||||
|
|
||||||
virtual int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
virtual int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
|
||||||
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);
|
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -44,7 +44,7 @@ class CpuExecutor : public Executor {
|
||||||
virtual ~CpuExecutor() = default;
|
virtual ~CpuExecutor() = default;
|
||||||
|
|
||||||
int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
|
||||||
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ int InnerContext::Init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (this->allocator == nullptr) {
|
if (this->allocator == nullptr) {
|
||||||
this->allocator = Allocator::Create();
|
this->allocator = mindspore::Allocator::Create();
|
||||||
if (this->allocator == nullptr) {
|
if (this->allocator == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create Allocator failed";
|
MS_LOG(ERROR) << "Create Allocator failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -105,7 +105,7 @@ int LiteKernel::PreProcess() {
|
||||||
|
|
||||||
for (auto *output : this->out_tensors()) {
|
for (auto *output : this->out_tensors()) {
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
if (output->ElementsNum() >= lite::MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
|
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
|
||||||
MS_LOG(ERROR) << "The size of output tensor is too big";
|
MS_LOG(ERROR) << "The size of output tensor is too big";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ int MindrtExecutor::Prepare(const std::vector<kernel::LiteKernel *> &kernels) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int MindrtExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
int MindrtExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
|
||||||
const KernelCallBack &before, const KernelCallBack &after) {
|
const KernelCallBack &before, const KernelCallBack &after) {
|
||||||
MS_ASSERT(nullptr != allocator);
|
MS_ASSERT(nullptr != allocator);
|
||||||
if (kernels.front()->Type() != schema::PrimitiveType_Merge) {
|
if (kernels.front()->Type() != schema::PrimitiveType_Merge) {
|
||||||
|
|
|
@ -34,7 +34,7 @@ class MindrtExecutor : public Executor {
|
||||||
virtual int Prepare(const std::vector<kernel::LiteKernel *> &kernels);
|
virtual int Prepare(const std::vector<kernel::LiteKernel *> &kernels);
|
||||||
|
|
||||||
virtual int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
virtual int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
|
||||||
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);
|
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore {
|
||||||
std::shared_ptr<Allocator> Allocator::Create() {
|
std::shared_ptr<Allocator> Allocator::Create() {
|
||||||
return std::shared_ptr<Allocator>(new (std::nothrow) DefaultAllocator());
|
return std::shared_ptr<Allocator>(new (std::nothrow) DefaultAllocator());
|
||||||
}
|
}
|
||||||
|
@ -110,4 +110,4 @@ void DefaultAllocator::Clear() {
|
||||||
freeList_.clear();
|
freeList_.clear();
|
||||||
UnLock();
|
UnLock();
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore
|
||||||
|
|
|
@ -25,7 +25,8 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore {
|
||||||
|
|
||||||
struct AllocatorContext {
|
struct AllocatorContext {
|
||||||
int shiftFactor;
|
int shiftFactor;
|
||||||
bool lockFlag;
|
bool lockFlag;
|
||||||
|
@ -75,6 +76,6 @@ class DefaultAllocator : public Allocator {
|
||||||
constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024;
|
constexpr int64_t MAX_MALLOC_SIZE = static_cast<size_t>(2000) * 1024 * 1024;
|
||||||
constexpr int64_t MAX_THREAD_POOL_SIZE = static_cast<size_t>(3000) * 1024 * 1024;
|
constexpr int64_t MAX_THREAD_POOL_SIZE = static_cast<size_t>(3000) * 1024 * 1024;
|
||||||
|
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_
|
#endif // MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_
|
||||||
|
|
|
@ -64,7 +64,7 @@ class DevKernel {
|
||||||
public:
|
public:
|
||||||
void *data{nullptr};
|
void *data{nullptr};
|
||||||
};
|
};
|
||||||
class GpuAllocator : public Allocator {};
|
class GpuAllocator : public mindspore::Allocator {};
|
||||||
class GpuRuntime {
|
class GpuRuntime {
|
||||||
public:
|
public:
|
||||||
GpuRuntime() {}
|
GpuRuntime() {}
|
||||||
|
|
|
@ -40,7 +40,7 @@ struct ImageSize {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class OpenCLAllocator : public Allocator {
|
class OpenCLAllocator : public mindspore::Allocator {
|
||||||
public:
|
public:
|
||||||
explicit OpenCLAllocator(OpenCLRuntime *ocl_runtime);
|
explicit OpenCLAllocator(OpenCLRuntime *ocl_runtime);
|
||||||
~OpenCLAllocator() override;
|
~OpenCLAllocator() override;
|
||||||
|
|
|
@ -22,13 +22,13 @@
|
||||||
namespace mindspore::lite::opencl {
|
namespace mindspore::lite::opencl {
|
||||||
|
|
||||||
int OpenCLExecutor::Run(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
int OpenCLExecutor::Run(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
|
||||||
const KernelCallBack &before, const KernelCallBack &after) {
|
const KernelCallBack &before, const KernelCallBack &after) {
|
||||||
return RunOrTune(inputs, outputs, kernels, allocator, before, after, false);
|
return RunOrTune(inputs, outputs, kernels, allocator, before, after, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
|
||||||
const KernelCallBack &before, const KernelCallBack &after, bool is_tune) {
|
const KernelCallBack &before, const KernelCallBack &after, bool is_tune) {
|
||||||
int ret{RET_OK};
|
int ret{RET_OK};
|
||||||
auto opencl_runtime_ins = ocl_runtime.GetInstance();
|
auto opencl_runtime_ins = ocl_runtime.GetInstance();
|
||||||
|
|
|
@ -32,10 +32,10 @@ class OpenCLExecutor : public Executor {
|
||||||
int Prepare(const std::vector<kernel::LiteKernel *> &kernels) override { return RET_OK; }
|
int Prepare(const std::vector<kernel::LiteKernel *> &kernels) override { return RET_OK; }
|
||||||
|
|
||||||
int Run(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
int Run(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
|
||||||
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
||||||
int RunOrTune(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
int RunOrTune(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
|
||||||
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr, bool is_tune = false);
|
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr, bool is_tune = false);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -73,7 +73,7 @@ int GatherFp16CPUKernel::PreProcess() {
|
||||||
|
|
||||||
for (auto *output : this->out_tensors()) {
|
for (auto *output : this->out_tensors()) {
|
||||||
MS_ASSERT(output != nullptr);
|
MS_ASSERT(output != nullptr);
|
||||||
if (output->ElementsNum() >= lite::MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
|
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
|
||||||
MS_LOG(ERROR) << "The size of output tensor is too big";
|
MS_LOG(ERROR) << "The size of output tensor is too big";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ static int RunKernel(void *data, int index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ParallelExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
int ParallelExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
|
||||||
const KernelCallBack &before, const KernelCallBack &after) {
|
const KernelCallBack &before, const KernelCallBack &after) {
|
||||||
MS_ASSERT(nullptr != allocator);
|
MS_ASSERT(nullptr != allocator);
|
||||||
for (auto &inTensor : in_tensors) {
|
for (auto &inTensor : in_tensors) {
|
||||||
|
|
|
@ -33,7 +33,7 @@ class ParallelExecutor : public Executor {
|
||||||
int Prepare(const std::vector<kernel::LiteKernel *> &kernels) override;
|
int Prepare(const std::vector<kernel::LiteKernel *> &kernels) override;
|
||||||
|
|
||||||
int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||||
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
|
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
|
||||||
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
|
||||||
inline kernel::LiteKernel *GetReadyKernel(const int index) const { return readyKernels.at(index); }
|
inline kernel::LiteKernel *GetReadyKernel(const int index) const { return readyKernels.at(index); }
|
||||||
inline void SetResult(const int index, const int result) { results.at(index) = result; }
|
inline void SetResult(const int index, const int result) { results.at(index) = result; }
|
||||||
|
|
|
@ -33,9 +33,9 @@ namespace mindspore::kernel {
|
||||||
// store origin data and allocator of input tensor of subgraph for PreProcess and PostProcess
|
// store origin data and allocator of input tensor of subgraph for PreProcess and PostProcess
|
||||||
struct DataStore {
|
struct DataStore {
|
||||||
void *data_ = nullptr;
|
void *data_ = nullptr;
|
||||||
lite::Allocator *allocator_ = nullptr;
|
mindspore::Allocator *allocator_ = nullptr;
|
||||||
static DataStore *CreateDataStore(void *data = nullptr, lite::Allocator *data_allocator = nullptr,
|
static DataStore *CreateDataStore(void *data = nullptr, mindspore::Allocator *data_allocator = nullptr,
|
||||||
lite::Allocator *allocator = nullptr) {
|
mindspore::Allocator *allocator = nullptr) {
|
||||||
DataStore *data_store = nullptr;
|
DataStore *data_store = nullptr;
|
||||||
if (allocator == nullptr) {
|
if (allocator == nullptr) {
|
||||||
data_store = static_cast<DataStore *>(malloc(sizeof(DataStore)));
|
data_store = static_cast<DataStore *>(malloc(sizeof(DataStore)));
|
||||||
|
|
|
@ -29,11 +29,6 @@ namespace lite {
|
||||||
Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const schema::Format &format, Category category)
|
Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const schema::Format &format, Category category)
|
||||||
: data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}
|
: data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}
|
||||||
|
|
||||||
Tensor::Tensor(const std::string &name, enum TypeId type, const std::vector<int32_t> &shape, const void *data)
|
|
||||||
: tensor_name_(name), data_type_(type), shape_(std::move(shape)), category_(VAR) {
|
|
||||||
data_ = const_cast<void *>(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
int Tensor::CopyTensorData(const Tensor &src_tensor, Tensor *dst_tensor) {
|
int Tensor::CopyTensorData(const Tensor &src_tensor, Tensor *dst_tensor) {
|
||||||
if (dst_tensor == nullptr) {
|
if (dst_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "dst_tensor is nullptr";
|
MS_LOG(ERROR) << "dst_tensor is nullptr";
|
||||||
|
@ -298,12 +293,12 @@ int Tensor::set_root_tensor(Tensor *tensor) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int Tensor::MallocData(const mindspore::lite::Allocator *allocator) {
|
int Tensor::MallocData(const mindspore::Allocator *allocator) {
|
||||||
if (nullptr != this->data_) {
|
if (nullptr != this->data_) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
if (allocator != nullptr) {
|
if (allocator != nullptr) {
|
||||||
allocator_ = const_cast<mindspore::lite::Allocator *>(allocator);
|
allocator_ = const_cast<mindspore::Allocator *>(allocator);
|
||||||
}
|
}
|
||||||
if (allocator_ == nullptr) {
|
if (allocator_ == nullptr) {
|
||||||
this->data_ = malloc(this->Size());
|
this->data_ = malloc(this->Size());
|
||||||
|
@ -380,5 +375,21 @@ std::vector<tensor::MSTensor *> TensorVectorCast(const std::vector<Tensor *> &sr
|
||||||
std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast<tensor::MSTensor *>(t); });
|
std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast<tensor::MSTensor *>(t); });
|
||||||
return target;
|
return target;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
|
|
||||||
|
tensor::MSTensor *tensor::MSTensor::CreateTensor(const std::string &name, TypeId type, const std::vector<int> &shape,
|
||||||
|
const void *data, size_t data_len) {
|
||||||
|
auto tensor = new (std::nothrow) lite::Tensor();
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to allocate tensor.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
tensor->set_data(const_cast<void *>(data));
|
||||||
|
tensor->set_shape(shape);
|
||||||
|
tensor->set_tensor_name(name);
|
||||||
|
tensor->set_data_type(type);
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -58,8 +58,6 @@ class Tensor : public mindspore::tensor::MSTensor {
|
||||||
Tensor(TypeId data_type, std::vector<int> shape, const schema::Format &format = schema::Format::Format_NHWC,
|
Tensor(TypeId data_type, std::vector<int> shape, const schema::Format &format = schema::Format::Format_NHWC,
|
||||||
Category category = VAR);
|
Category category = VAR);
|
||||||
|
|
||||||
Tensor(const std::string &name, enum TypeId type, const std::vector<int32_t> &shape, const void *data);
|
|
||||||
|
|
||||||
Tensor(const Tensor &tensor) = delete;
|
Tensor(const Tensor &tensor) = delete;
|
||||||
|
|
||||||
Tensor(Tensor &&other) = delete;
|
Tensor(Tensor &&other) = delete;
|
||||||
|
@ -86,9 +84,9 @@ class Tensor : public mindspore::tensor::MSTensor {
|
||||||
|
|
||||||
std::vector<int> shape() const override { return shape_; }
|
std::vector<int> shape() const override { return shape_; }
|
||||||
|
|
||||||
void set_shape(const std::vector<int> &shape) { shape_ = shape; }
|
void set_shape(const std::vector<int> &shape) override { shape_ = shape; }
|
||||||
|
|
||||||
int DimensionSize(size_t index) const override;
|
int DimensionSize(size_t index) const;
|
||||||
|
|
||||||
int ElementsNum() const override;
|
int ElementsNum() const override;
|
||||||
|
|
||||||
|
@ -104,16 +102,18 @@ class Tensor : public mindspore::tensor::MSTensor {
|
||||||
|
|
||||||
size_t Size() const override;
|
size_t Size() const override;
|
||||||
|
|
||||||
void set_allocator(mindspore::lite::Allocator *allocator) { allocator_ = allocator; }
|
void set_allocator(mindspore::Allocator *allocator) { allocator_ = allocator; }
|
||||||
|
|
||||||
mindspore::lite::Allocator *allocator() const { return this->allocator_; }
|
mindspore::Allocator *allocator() const { return this->allocator_; }
|
||||||
|
|
||||||
virtual int MallocData(const mindspore::lite::Allocator *allocator = nullptr);
|
virtual int MallocData(const mindspore::Allocator *allocator = nullptr);
|
||||||
|
|
||||||
virtual void FreeData();
|
virtual void FreeData();
|
||||||
|
|
||||||
void *MutableData() override;
|
void *MutableData() override;
|
||||||
|
|
||||||
|
void *data() override { return this->data_; }
|
||||||
|
|
||||||
virtual void *data_c() const {
|
virtual void *data_c() const {
|
||||||
if (this->root_tensor_ != nullptr) {
|
if (this->root_tensor_ != nullptr) {
|
||||||
return this->root_tensor_->data_;
|
return this->root_tensor_->data_;
|
||||||
|
@ -206,7 +206,7 @@ class Tensor : public mindspore::tensor::MSTensor {
|
||||||
size_t init_ref_count_ = 0;
|
size_t init_ref_count_ = 0;
|
||||||
std::vector<QuantArg> quant_params_;
|
std::vector<QuantArg> quant_params_;
|
||||||
std::vector<float> quant_clusters_;
|
std::vector<float> quant_clusters_;
|
||||||
mindspore::lite::Allocator *allocator_ = nullptr;
|
mindspore::Allocator *allocator_ = nullptr;
|
||||||
Tensor *root_tensor_ = nullptr;
|
Tensor *root_tensor_ = nullptr;
|
||||||
bool enable_huffman_code_ = false;
|
bool enable_huffman_code_ = false;
|
||||||
};
|
};
|
||||||
|
|
|
@ -113,9 +113,9 @@ int TensorList::MallocTensorListData(TypeId dtype, const std::vector<std::vector
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int TensorList::MallocData(const mindspore::lite::Allocator *allocator) {
|
int TensorList::MallocData(const mindspore::Allocator *allocator) {
|
||||||
if (allocator != nullptr) {
|
if (allocator != nullptr) {
|
||||||
allocator_ = const_cast<mindspore::lite::Allocator *>(allocator);
|
allocator_ = const_cast<mindspore::Allocator *>(allocator);
|
||||||
}
|
}
|
||||||
// malloc data buf of each tensor in tensors_
|
// malloc data buf of each tensor in tensors_
|
||||||
for (int i = 0; i < this->ElementsNum(); ++i) {
|
for (int i = 0; i < this->ElementsNum(); ++i) {
|
||||||
|
|
|
@ -77,7 +77,7 @@ class TensorList : public Tensor {
|
||||||
|
|
||||||
int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape);
|
int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape);
|
||||||
|
|
||||||
int MallocData(const mindspore::lite::Allocator *allocator = nullptr) override;
|
int MallocData(const mindspore::Allocator *allocator = nullptr) override;
|
||||||
|
|
||||||
int FreeTensorListData();
|
int FreeTensorListData();
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include "mindspore/lite/src/kernel_registry.h"
|
#include "mindspore/lite/src/kernel_registry.h"
|
||||||
#include "mindspore/lite/src/runtime/allocator.h"
|
#include "mindspore/lite/src/runtime/allocator.h"
|
||||||
|
|
||||||
using mindspore::lite::Allocator;
|
|
||||||
using mindspore::lite::Tensor;
|
using mindspore::lite::Tensor;
|
||||||
using mindspore::schema::ReduceMode;
|
using mindspore::schema::ReduceMode;
|
||||||
using mindspore::schema::ReduceMode_ReduceASum;
|
using mindspore::schema::ReduceMode_ReduceASum;
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "common/common_test.h"
|
#include "common/common_test.h"
|
||||||
#include "include/api/context.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
#if __cplusplus
|
#if __cplusplus
|
||||||
|
@ -58,10 +57,10 @@ void Common::ReadFile(const char *file, size_t *size, char **buf) {
|
||||||
ifs.close();
|
ifs.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Common::ContextAutoSet() {
|
std::shared_ptr<mindspore::Context> Common::ContextAutoSet() {
|
||||||
auto device_target = GetEnv("DEVICE_TARGET");
|
auto device_target_str = GetEnv("DEVICE_TARGET");
|
||||||
if (device_target.empty()) {
|
if (device_target_str.empty()) {
|
||||||
device_target = mindspore::kDeviceTypeAscend310; // default is 310
|
device_target_str = "Ascend310"; // default is 310
|
||||||
}
|
}
|
||||||
|
|
||||||
auto device_id_str = GetEnv("DEVICE_ID");
|
auto device_id_str = GetEnv("DEVICE_ID");
|
||||||
|
@ -69,9 +68,21 @@ void Common::ContextAutoSet() {
|
||||||
device_id_str = "0"; // default is 0
|
device_id_str = "0"; // default is 0
|
||||||
}
|
}
|
||||||
uint32_t device_id = std::strtoul(device_id_str.c_str(), nullptr, 10);
|
uint32_t device_id = std::strtoul(device_id_str.c_str(), nullptr, 10);
|
||||||
|
auto context = std::make_shared<mindspore::Context>();
|
||||||
|
|
||||||
mindspore::GlobalContext::SetGlobalDeviceTarget(device_target);
|
if (device_target_str == "Ascend310") {
|
||||||
mindspore::GlobalContext::SetGlobalDeviceID(device_id);
|
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||||
|
ascend310_info->SetDeviceID(device_id);
|
||||||
|
context->MutableDeviceInfo().emplace_back(ascend310_info);
|
||||||
|
} else if (device_target_str == "Ascend910") {
|
||||||
|
auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
||||||
|
ascend310_info->SetDeviceID(device_id);
|
||||||
|
context->MutableDeviceInfo().emplace_back(ascend310_info);
|
||||||
|
} else {
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
return context;
|
||||||
}
|
}
|
||||||
} // namespace ST
|
} // namespace ST
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "include/api/context.h"
|
||||||
|
|
||||||
namespace ST {
|
namespace ST {
|
||||||
class Common : public testing::Test {
|
class Common : public testing::Test {
|
||||||
public:
|
public:
|
||||||
|
@ -56,7 +58,7 @@ class Common : public testing::Test {
|
||||||
|
|
||||||
void ReadFile(const char *file, size_t *size, char **buf);
|
void ReadFile(const char *file, size_t *size, char **buf);
|
||||||
|
|
||||||
void ContextAutoSet();
|
std::shared_ptr<mindspore::Context> ContextAutoSet();
|
||||||
};
|
};
|
||||||
} // namespace ST
|
} // namespace ST
|
||||||
#endif // TESTS_CXX_ST_COMMON_COMMON_TEST_H_
|
#endif // TESTS_CXX_ST_COMMON_COMMON_TEST_H_
|
||||||
|
|
|
@ -98,6 +98,12 @@ TEST_F(TestDE, TestDvpp) {
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
auto image = MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
|
auto image = MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
|
||||||
*/
|
*/
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
ASSERT_TRUE(context != nullptr);
|
||||||
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
||||||
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
||||||
auto image = ReadFileToTensor("./data/dataset/apple.jpg");
|
auto image = ReadFileToTensor("./data/dataset/apple.jpg");
|
||||||
|
|
||||||
|
@ -105,7 +111,7 @@ TEST_F(TestDE, TestDvpp) {
|
||||||
std::vector<uint32_t> crop_paras = {224, 224};
|
std::vector<uint32_t> crop_paras = {224, 224};
|
||||||
std::vector<uint32_t> resize_paras = {256, 256};
|
std::vector<uint32_t> resize_paras = {256, 256};
|
||||||
std::shared_ptr<TensorTransform> decode_resize_crop(new vision::DvppDecodeResizeCropJpeg(crop_paras, resize_paras));
|
std::shared_ptr<TensorTransform> decode_resize_crop(new vision::DvppDecodeResizeCropJpeg(crop_paras, resize_paras));
|
||||||
mindspore::dataset::Execute Transform(decode_resize_crop, MapTargetDevice::kAscend310);
|
mindspore::dataset::Execute Transform(decode_resize_crop, MapTargetDevice::kAscend310, device_id);
|
||||||
|
|
||||||
// Apply transform on images
|
// Apply transform on images
|
||||||
Status rc = Transform(image, &image);
|
Status rc = Transform(image, &image);
|
||||||
|
@ -145,6 +151,13 @@ TEST_F(TestDE, TestDvpp) {
|
||||||
|
|
||||||
TEST_F(TestDE, TestDvppSinkMode) {
|
TEST_F(TestDE, TestDvppSinkMode) {
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
ASSERT_TRUE(context != nullptr);
|
||||||
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
||||||
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
||||||
// Read images from target directory
|
// Read images from target directory
|
||||||
auto image = ReadFileToTensor("./data/dataset/apple.jpg");
|
auto image = ReadFileToTensor("./data/dataset/apple.jpg");
|
||||||
|
|
||||||
|
@ -155,7 +168,7 @@ TEST_F(TestDE, TestDvppSinkMode) {
|
||||||
std::shared_ptr<TensorTransform> resize(new vision::Resize(resize_paras));
|
std::shared_ptr<TensorTransform> resize(new vision::Resize(resize_paras));
|
||||||
std::shared_ptr<TensorTransform> centercrop(new vision::CenterCrop(crop_paras));
|
std::shared_ptr<TensorTransform> centercrop(new vision::CenterCrop(crop_paras));
|
||||||
std::vector<std::shared_ptr<TensorTransform>> trans_list = {decode, resize, centercrop};
|
std::vector<std::shared_ptr<TensorTransform>> trans_list = {decode, resize, centercrop};
|
||||||
mindspore::dataset::Execute Transform(trans_list, MapTargetDevice::kAscend310);
|
mindspore::dataset::Execute Transform(trans_list, MapTargetDevice::kAscend310, device_id);
|
||||||
|
|
||||||
// Apply transform on images
|
// Apply transform on images
|
||||||
Status rc = Transform(image, &image);
|
Status rc = Transform(image, &image);
|
||||||
|
@ -186,6 +199,13 @@ TEST_F(TestDE, TestDvppSinkMode) {
|
||||||
|
|
||||||
TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
|
TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
|
auto context = ContextAutoSet();
|
||||||
|
ASSERT_TRUE(context != nullptr);
|
||||||
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
||||||
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
|
|
||||||
auto image = ReadFileToTensor("./data/dataset/apple.jpg");
|
auto image = ReadFileToTensor("./data/dataset/apple.jpg");
|
||||||
|
|
||||||
// Define dvpp transform
|
// Define dvpp transform
|
||||||
|
@ -200,7 +220,7 @@ TEST_F(TestDE, TestDvppDecodeResizeCropNormalize) {
|
||||||
std::shared_ptr<TensorTransform> normalize(new vision::Normalize(mean, std));
|
std::shared_ptr<TensorTransform> normalize(new vision::Normalize(mean, std));
|
||||||
|
|
||||||
std::vector<std::shared_ptr<TensorTransform>> trans_list = {decode, resize, centercrop, normalize};
|
std::vector<std::shared_ptr<TensorTransform>> trans_list = {decode, resize, centercrop, normalize};
|
||||||
mindspore::dataset::Execute Transform(trans_list, MapTargetDevice::kAscend310);
|
mindspore::dataset::Execute Transform(trans_list, MapTargetDevice::kAscend310, device_id);
|
||||||
|
|
||||||
std::string aipp_cfg = Transform.AippCfgGenerator();
|
std::string aipp_cfg = Transform.AippCfgGenerator();
|
||||||
ASSERT_EQ(aipp_cfg, "./aipp.cfg");
|
ASSERT_EQ(aipp_cfg, "./aipp.cfg");
|
||||||
|
|
|
@ -24,62 +24,68 @@
|
||||||
using namespace mindspore;
|
using namespace mindspore;
|
||||||
|
|
||||||
static const char tensor_add_file[] = "/home/workspace/mindspore_dataset/mindir/add/add.mindir";
|
static const char tensor_add_file[] = "/home/workspace/mindspore_dataset/mindir/add/add.mindir";
|
||||||
static const float input_data_1[2][2] = {{1,2},{3,4}};
|
static const float input_data_1[2][2] = {{1, 2}, {3, 4}};
|
||||||
static const float input_data_2[2][2] = {{2,3},{4,5}};
|
static const float input_data_2[2][2] = {{2, 3}, {4, 5}};
|
||||||
static const float input_data_3[1] ={2};
|
static const float input_data_3[1] = {2};
|
||||||
|
|
||||||
class TestDynamicBatchSize : public ST::Common {
|
class TestDynamicBatchSize : public ST::Common {
|
||||||
public:
|
public:
|
||||||
TestDynamicBatchSize() {}
|
TestDynamicBatchSize() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestDynamicBatchSize, InferMindIR) {
|
TEST_F(TestDynamicBatchSize, InferMindIR) {
|
||||||
mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
|
#ifdef ENABLE_ACL
|
||||||
mindspore::GlobalContext::SetGlobalDeviceID(2);
|
auto context = ContextAutoSet();
|
||||||
std::map<int,std::vector<int>> input_shape;
|
ASSERT_TRUE(context != nullptr);
|
||||||
input_shape.insert(std::make_pair(0,std::vector<int>{-1,2}));
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
input_shape.insert(std::make_pair(1,std::vector<int>{-1,2}));
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
||||||
auto model_context = std::make_shared<ModelContext>();
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
std::vector<size_t> dynamic_batch_size ={1,2,4,8};
|
|
||||||
ModelContext::SetDynamicBatchSize(model_context,dynamic_batch_size);
|
std::map<int, std::vector<int>> input_shape;
|
||||||
ModelContext::SetInputShapeMap(model_context,input_shape);
|
input_shape.insert(std::make_pair(0, std::vector<int>{-1, 2}));
|
||||||
auto graph = Serialization::LoadModel(tensor_add_file, ModelType::kMindIR);
|
input_shape.insert(std::make_pair(1, std::vector<int>{-1, 2}));
|
||||||
Model tensor_add(GraphCell(graph),model_context);
|
std::vector<size_t> dynamic_batch_size = {1, 2, 4, 8};
|
||||||
ASSERT_TRUE(tensor_add.Build() == kSuccess);
|
ascend310_info->SetDynamicBatchSize(dynamic_batch_size);
|
||||||
|
ascend310_info->SetInputShapeMap(input_shape);
|
||||||
|
|
||||||
|
Graph graph;
|
||||||
|
ASSERT_TRUE(Serialization::Load(tensor_add_file, ModelType::kMindIR, &graph) == kSuccess);
|
||||||
|
Model tensor_add;
|
||||||
|
ASSERT_TRUE(tensor_add.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
// get model inputs
|
// get model inputs
|
||||||
std::vector<MSTensor> origin_inputs = tensor_add.GetInputs();
|
std::vector<MSTensor> origin_inputs = tensor_add.GetInputs();
|
||||||
ASSERT_EQ(origin_inputs.size()-1, 2);
|
ASSERT_EQ(origin_inputs.size() - 1, 2);
|
||||||
|
|
||||||
// prepare input
|
// prepare input
|
||||||
std::vector<MSTensor> outputs;
|
std::vector<MSTensor> outputs;
|
||||||
std::vector<MSTensor> inputs;
|
std::vector<MSTensor> inputs;
|
||||||
size_t row = sizeof(input_data_1)/sizeof(input_data_1[0]);
|
size_t row = sizeof(input_data_1) / sizeof(input_data_1[0]);
|
||||||
size_t col = sizeof(input_data_1[0])/sizeof(input_data_1[0][0]);;
|
size_t col = sizeof(input_data_1[0]) / sizeof(input_data_1[0][0]);
|
||||||
inputs.emplace_back(origin_inputs[0].Name(), origin_inputs[0].DataType(), origin_inputs[0].Shape(),
|
inputs.emplace_back(origin_inputs[0].Name(), origin_inputs[0].DataType(), origin_inputs[0].Shape(), input_data_1,
|
||||||
input_data_1, sizeof(float) * row*col);
|
sizeof(float) * row * col);
|
||||||
inputs.emplace_back(origin_inputs[1].Name(), origin_inputs[1].DataType(), origin_inputs[1].Shape(),
|
inputs.emplace_back(origin_inputs[1].Name(), origin_inputs[1].DataType(), origin_inputs[1].Shape(), input_data_2,
|
||||||
input_data_2, sizeof(float) * row*col);
|
sizeof(float) * row * col);
|
||||||
inputs.emplace_back(origin_inputs[2].Name(), origin_inputs[2].DataType(), origin_inputs[2].Shape(),
|
inputs.emplace_back(origin_inputs[2].Name(), origin_inputs[2].DataType(), origin_inputs[2].Shape(), input_data_3,
|
||||||
input_data_3, sizeof(float) * 1);
|
sizeof(float) * 1);
|
||||||
|
|
||||||
// infer
|
// infer
|
||||||
ASSERT_TRUE(tensor_add.Predict(inputs, &outputs) == kSuccess);
|
ASSERT_TRUE(tensor_add.Predict(inputs, &outputs) == kSuccess);
|
||||||
|
|
||||||
// assert input
|
// assert input
|
||||||
inputs = tensor_add.GetInputs();
|
inputs = tensor_add.GetInputs();
|
||||||
ASSERT_EQ(inputs.size()-1, 2);
|
ASSERT_EQ(inputs.size() - 1, 2);
|
||||||
auto after_input_data_1 = inputs[0].Data();
|
auto after_input_data_1 = inputs[0].Data();
|
||||||
auto after_input_data_2 = inputs[1].Data();
|
auto after_input_data_2 = inputs[1].Data();
|
||||||
const float *p = reinterpret_cast<const float *>(after_input_data_1.get());
|
const float *p = reinterpret_cast<const float *>(after_input_data_1.get());
|
||||||
float input_data1[inputs[0].DataSize() / sizeof(float)] ={0};
|
float input_data1[inputs[0].DataSize() / sizeof(float)] = {0};
|
||||||
float input_data2[inputs[1].DataSize() / sizeof(float)] ={0};
|
float input_data2[inputs[1].DataSize() / sizeof(float)] = {0};
|
||||||
size_t k=0,t=0;
|
size_t k = 0, t = 0;
|
||||||
for(size_t i=0;i<row;i++)
|
for (size_t i = 0; i < row; i++)
|
||||||
for(size_t j=0;j<col;j++){
|
for (size_t j = 0; j < col; j++) {
|
||||||
input_data1[k++]=input_data_1[i][j];
|
input_data1[k++] = input_data_1[i][j];
|
||||||
input_data2[t++]=input_data_2[i][j];
|
input_data2[t++] = input_data_2[i][j];
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < inputs[0].DataSize() / sizeof(float); ++i) {
|
for (size_t i = 0; i < inputs[0].DataSize() / sizeof(float); ++i) {
|
||||||
ASSERT_LE(std::abs(p[i] - input_data1[i]), 1e-4);
|
ASSERT_LE(std::abs(p[i] - input_data1[i]), 1e-4);
|
||||||
}
|
}
|
||||||
|
@ -96,4 +102,5 @@ TEST_F(TestDynamicBatchSize, InferMindIR) {
|
||||||
ASSERT_LE(std::abs(p[i] - (input_data1[i] + input_data2[i])), 1e-4);
|
ASSERT_LE(std::abs(p[i] - (input_data1[i] + input_data2[i])), 1e-4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // ENABLE_ACL
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,11 +32,12 @@ class TestAdd : public ST::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestAdd, InferMindIR) {
|
TEST_F(TestAdd, InferMindIR) {
|
||||||
ContextAutoSet();
|
auto context = ContextAutoSet();
|
||||||
|
|
||||||
auto graph = Serialization::LoadModel(tensor_add_file, ModelType::kMindIR);
|
Graph graph;
|
||||||
Model tensor_add((GraphCell(graph)));
|
ASSERT_TRUE(Serialization::Load(tensor_add_file, ModelType::kMindIR, &graph));
|
||||||
ASSERT_TRUE(tensor_add.Build() == kSuccess);
|
Model tensor_add;
|
||||||
|
ASSERT_TRUE(tensor_add.Build(GraphCell(graph), context) == kSuccess);
|
||||||
|
|
||||||
// get model inputs
|
// get model inputs
|
||||||
std::vector<MSTensor> origin_inputs = tensor_add.GetInputs();
|
std::vector<MSTensor> origin_inputs = tensor_add.GetInputs();
|
||||||
|
|
|
@ -51,46 +51,49 @@ std::vector<std::string> GetAllFiles(std::string_view dir_name);
|
||||||
|
|
||||||
TEST_F(TestZeroCopy, TestMindIR) {
|
TEST_F(TestZeroCopy, TestMindIR) {
|
||||||
#ifdef ENABLE_ACL
|
#ifdef ENABLE_ACL
|
||||||
// Set context
|
// Set context
|
||||||
mindspore::GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
|
auto context = ContextAutoSet();
|
||||||
mindspore::GlobalContext::SetGlobalDeviceID(0);
|
ASSERT_TRUE(context != nullptr);
|
||||||
auto model_context = std::make_shared<ModelContext>();
|
ASSERT_TRUE(context->MutableDeviceInfo().size() == 1);
|
||||||
ModelContext::SetInsertOpConfigPath(model_context,aipp_path);
|
auto ascend310_info = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
||||||
// Define model
|
ASSERT_TRUE(ascend310_info != nullptr);
|
||||||
auto graph = mindspore::Serialization::LoadModel(resnet_file, mindspore::ModelType::kMindIR);
|
ascend310_info->SetInsertOpConfigPath(aipp_path);
|
||||||
mindspore::Model resnet50(mindspore::GraphCell(graph),model_context);
|
auto device_id = ascend310_info->GetDeviceID();
|
||||||
// Build model
|
// Define model
|
||||||
ASSERT_TRUE(resnet50.Build() == kSuccess);
|
Graph graph;
|
||||||
// Get model info
|
ASSERT_TRUE(Serialization::Load(resnet_file, ModelType::kMindIR, &graph) == kSuccess);
|
||||||
std::vector<mindspore::MSTensor> model_inputs =resnet50.GetInputs();
|
Model resnet50;
|
||||||
ASSERT_EQ(model_inputs.size(), 1);
|
ASSERT_TRUE(resnet50.Build(GraphCell(graph), context) == kSuccess);
|
||||||
// Define transform operations
|
// Get model info
|
||||||
std::shared_ptr<TensorTransform> decode(new vision::Decode());
|
std::vector<mindspore::MSTensor> model_inputs = resnet50.GetInputs();
|
||||||
std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
|
ASSERT_EQ(model_inputs.size(), 1);
|
||||||
std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224,224}));
|
// Define transform operations
|
||||||
mindspore::dataset::Execute Transform({decode,resize,center_crop},MapTargetDevice::kAscend310);
|
std::shared_ptr<TensorTransform> decode(new vision::Decode());
|
||||||
size_t count=0;
|
std::shared_ptr<TensorTransform> resize(new vision::Resize({256}));
|
||||||
// Read images
|
std::shared_ptr<TensorTransform> center_crop(new vision::CenterCrop({224, 224}));
|
||||||
std::vector<std::string> images =GetAllFiles(image_path);
|
mindspore::dataset::Execute Transform({decode, resize, center_crop}, MapTargetDevice::kAscend310, device_id);
|
||||||
for(const auto &image_file:images){
|
size_t count = 0;
|
||||||
// prepare input
|
// Read images
|
||||||
std::vector<mindspore::MSTensor> inputs;
|
std::vector<std::string> images = GetAllFiles(image_path);
|
||||||
std::vector<mindspore::MSTensor> outputs;
|
for (const auto &image_file : images) {
|
||||||
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
|
// prepare input
|
||||||
mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
|
std::vector<mindspore::MSTensor> inputs;
|
||||||
auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
|
std::vector<mindspore::MSTensor> outputs;
|
||||||
// Apply transform on images
|
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
|
||||||
Status rc = Transform(image, &image);
|
mindspore::dataset::Tensor::CreateFromFile(image_file, &de_tensor);
|
||||||
ASSERT_TRUE(rc == kSuccess);
|
auto image = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
|
||||||
inputs.push_back(image);
|
// Apply transform on images
|
||||||
// infer
|
Status rc = Transform(image, &image);
|
||||||
ASSERT_TRUE(resnet50.Predict(inputs, &outputs)==kSuccess);
|
ASSERT_TRUE(rc == kSuccess);
|
||||||
if(GetMax(outputs[0])==0){
|
inputs.push_back(image);
|
||||||
++count;
|
// infer
|
||||||
}
|
ASSERT_TRUE(resnet50.Predict(inputs, &outputs) == kSuccess);
|
||||||
Transform.DeviceMemoryRelease();
|
if (GetMax(outputs[0]) == 0) {
|
||||||
|
++count;
|
||||||
}
|
}
|
||||||
ASSERT_GE(static_cast<double>(count)/images.size()*100.0, 20.0);
|
Transform.DeviceMemoryRelease();
|
||||||
|
}
|
||||||
|
ASSERT_GE(static_cast<double>(count) / images.size() * 100.0, 20.0);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,8 +152,7 @@ std::vector<std::string> GetAllFiles(std::string_view dir_name) {
|
||||||
while ((filename = readdir(dir)) != nullptr) {
|
while ((filename = readdir(dir)) != nullptr) {
|
||||||
std::string d_name = std::string(filename->d_name);
|
std::string d_name = std::string(filename->d_name);
|
||||||
// get rid of "." and ".."
|
// get rid of "." and ".."
|
||||||
if (d_name == "." || d_name == ".." || filename->d_type != DT_REG)
|
if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) continue;
|
||||||
continue;
|
|
||||||
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
|
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,45 +23,102 @@ class TestCxxApiContext : public UT::Common {
|
||||||
TestCxxApiContext() = default;
|
TestCxxApiContext() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_global_context_SUCCESS) {
|
TEST_F(TestCxxApiContext, test_context_device_info_cast_SUCCESS) {
|
||||||
std::string device_target = "2333";
|
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
|
||||||
uint32_t device_id = 2333;
|
std::shared_ptr<DeviceInfoContext> mali_gpu = std::make_shared<MaliGPUDeviceInfo>();
|
||||||
GlobalContext::SetGlobalDeviceTarget(device_target);
|
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
|
||||||
ASSERT_EQ(GlobalContext::GetGlobalDeviceTarget(), device_target);
|
std::shared_ptr<DeviceInfoContext> nvidia_gpu = std::make_shared<NvidiaGPUDeviceInfo>();
|
||||||
GlobalContext::SetGlobalDeviceID(device_id);
|
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
|
||||||
ASSERT_EQ(GlobalContext::GetGlobalDeviceID(), device_id);
|
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
|
||||||
|
|
||||||
|
ASSERT_TRUE(cpu->Cast<CPUDeviceInfo>() != nullptr);
|
||||||
|
ASSERT_TRUE(mali_gpu->Cast<MaliGPUDeviceInfo>() != nullptr);
|
||||||
|
ASSERT_TRUE(kirin_npu->Cast<KirinNPUDeviceInfo>() != nullptr);
|
||||||
|
ASSERT_TRUE(nvidia_gpu->Cast<NvidiaGPUDeviceInfo>() != nullptr);
|
||||||
|
ASSERT_TRUE(ascend310->Cast<Ascend310DeviceInfo>() != nullptr);
|
||||||
|
ASSERT_TRUE(ascend910->Cast<Ascend910DeviceInfo>() != nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_ascend310_context_SUCCESS) {
|
TEST_F(TestCxxApiContext, test_context_device_info_cast_FAILED) {
|
||||||
|
std::shared_ptr<DeviceInfoContext> cpu = std::make_shared<CPUDeviceInfo>();
|
||||||
|
std::shared_ptr<DeviceInfoContext> mali_gpu = std::make_shared<MaliGPUDeviceInfo>();
|
||||||
|
std::shared_ptr<DeviceInfoContext> kirin_npu = std::make_shared<KirinNPUDeviceInfo>();
|
||||||
|
std::shared_ptr<DeviceInfoContext> nvidia_gpu = std::make_shared<NvidiaGPUDeviceInfo>();
|
||||||
|
std::shared_ptr<DeviceInfoContext> ascend310 = std::make_shared<Ascend310DeviceInfo>();
|
||||||
|
std::shared_ptr<DeviceInfoContext> ascend910 = std::make_shared<Ascend910DeviceInfo>();
|
||||||
|
|
||||||
|
ASSERT_TRUE(cpu->Cast<MaliGPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(kirin_npu->Cast<MaliGPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(nvidia_gpu->Cast<MaliGPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(ascend310->Cast<MaliGPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(ascend910->Cast<MaliGPUDeviceInfo>() == nullptr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(mali_gpu->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(kirin_npu->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(nvidia_gpu->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(ascend310->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
|
ASSERT_TRUE(ascend910->Cast<CPUDeviceInfo>() == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
|
||||||
|
int32_t thread_num = 22;
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
context->SetThreadNum(thread_num);
|
||||||
|
ASSERT_EQ(context->GetThreadNum(), thread_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiContext, test_context_cpu_context_SUCCESS) {
|
||||||
|
auto context = std::make_shared<Context>();
|
||||||
|
std::shared_ptr<CPUDeviceInfo> cpu = std::make_shared<CPUDeviceInfo>();
|
||||||
|
cpu->SetEnableFP16(true);
|
||||||
|
context->MutableDeviceInfo().push_back(cpu);
|
||||||
|
ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
|
||||||
|
auto cpu_2 = context->MutableDeviceInfo()[0]->Cast<CPUDeviceInfo>();
|
||||||
|
ASSERT_TRUE(cpu_2 != nullptr);
|
||||||
|
ASSERT_TRUE(cpu_2->GetEnableFP16());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
|
||||||
std::string option_1 = "aaa";
|
std::string option_1 = "aaa";
|
||||||
std::string option_2 = "vvv";
|
std::string option_2 = "vvv";
|
||||||
std::string option_3 = "www";
|
std::string option_3 = "www";
|
||||||
auto option_4 = DataType::kNumberTypeEnd;
|
std::string option_4 = "rrr";
|
||||||
std::string option_5 = "rrr";
|
std::string option_5 = "ppp";
|
||||||
std::string option_6 = "ppp";
|
std::string option_6 = "sss";
|
||||||
auto ctx = std::make_shared<ModelContext>();
|
uint32_t option_7 = 77;
|
||||||
ModelContext::SetInsertOpConfigPath(ctx, option_1);
|
enum DataType option_8 = DataType::kNumberTypeInt16;
|
||||||
ModelContext::SetInputFormat(ctx, option_2);
|
std::vector<size_t> option_9 = {1, 2, 3, 4, 5};
|
||||||
ModelContext::SetInputShape(ctx, option_3);
|
std::string option_9_ans = "1,2,3,4,5";
|
||||||
ModelContext::SetOutputType(ctx, option_4);
|
|
||||||
ModelContext::SetPrecisionMode(ctx, option_5);
|
|
||||||
ModelContext::SetOpSelectImplMode(ctx, option_6);
|
|
||||||
|
|
||||||
ASSERT_EQ(ModelContext::GetInsertOpConfigPath(ctx), option_1);
|
auto context = std::make_shared<Context>();
|
||||||
ASSERT_EQ(ModelContext::GetInputFormat(ctx), option_2);
|
std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>();
|
||||||
ASSERT_EQ(ModelContext::GetInputShape(ctx), option_3);
|
ascend310->SetInputShape(option_1);
|
||||||
ASSERT_EQ(ModelContext::GetOutputType(ctx), option_4);
|
ascend310->SetInsertOpConfigPath(option_2);
|
||||||
ASSERT_EQ(ModelContext::GetPrecisionMode(ctx), option_5);
|
ascend310->SetOpSelectImplMode(option_3);
|
||||||
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), option_6);
|
ascend310->SetPrecisionMode(option_4);
|
||||||
}
|
ascend310->SetInputFormat(option_5);
|
||||||
|
ascend310->SetFusionSwitchConfigPath(option_6);
|
||||||
|
ascend310->SetDeviceID(option_7);
|
||||||
|
ascend310->SetOutputType(option_8);
|
||||||
|
ascend310->SetDynamicBatchSize(option_9);
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_ascend310_context_nullptr_FAILED) {
|
context->MutableDeviceInfo().push_back(ascend310);
|
||||||
auto ctx = std::make_shared<ModelContext>();
|
ASSERT_EQ(context->MutableDeviceInfo().size(), 1);
|
||||||
EXPECT_ANY_THROW(ModelContext::GetInsertOpConfigPath(nullptr));
|
auto ctx = context->MutableDeviceInfo()[0]->Cast<Ascend310DeviceInfo>();
|
||||||
|
ASSERT_TRUE(ctx != nullptr);
|
||||||
|
ASSERT_EQ(ascend310->GetInputShape(), option_1);
|
||||||
|
ASSERT_EQ(ascend310->GetInsertOpConfigPath(), option_2);
|
||||||
|
ASSERT_EQ(ascend310->GetOpSelectImplMode(), option_3);
|
||||||
|
ASSERT_EQ(ascend310->GetPrecisionMode(), option_4);
|
||||||
|
ASSERT_EQ(ascend310->GetInputFormat(), option_5);
|
||||||
|
ASSERT_EQ(ascend310->GetFusionSwitchConfigPath(), option_6);
|
||||||
|
ASSERT_EQ(ascend310->GetDeviceID(), option_7);
|
||||||
|
ASSERT_EQ(ascend310->GetOutputType(), option_8);
|
||||||
|
ASSERT_EQ(ascend310->GetDynamicBatchSize(), option_9_ans);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
|
TEST_F(TestCxxApiContext, test_context_ascend310_context_default_value_SUCCESS) {
|
||||||
auto ctx = std::make_shared<ModelContext>();
|
auto ctx = std::make_shared<Ascend310DeviceInfo>();
|
||||||
ASSERT_EQ(ModelContext::GetOpSelectImplMode(ctx), "");
|
ASSERT_EQ(ctx->GetOpSelectImplMode(), "");
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -76,12 +76,13 @@ TEST_F(TestCxxApiTypes, test_tensor_ref_SUCCESS) {
|
||||||
TEST_F(TestCxxApiTypes, test_tensor_clone_SUCCESS) {
|
TEST_F(TestCxxApiTypes, test_tensor_clone_SUCCESS) {
|
||||||
std::vector<int32_t> data = {1, 2, 3, 4};
|
std::vector<int32_t> data = {1, 2, 3, 4};
|
||||||
MSTensor tensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
MSTensor tensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
||||||
MSTensor tensor2 = tensor.Clone();
|
MSTensor *tensor2 = tensor.Clone();
|
||||||
auto value = tensor2.Data();
|
auto value = tensor2->Data();
|
||||||
int32_t *p = (int32_t *)value.get();
|
int32_t *p = (int32_t *)value.get();
|
||||||
for (size_t i = 0; i < data.size(); ++i) {
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
ASSERT_EQ(p[i], data[i]);
|
ASSERT_EQ(p[i], data[i]);
|
||||||
}
|
}
|
||||||
|
MSTensor::DestroyTensorPtr(tensor2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiTypes, test_tensor_ref_modified_SUCCESS) {
|
TEST_F(TestCxxApiTypes, test_tensor_ref_modified_SUCCESS) {
|
||||||
|
@ -101,37 +102,76 @@ TEST_F(TestCxxApiTypes, test_tensor_clone_modified_SUCCESS) {
|
||||||
std::vector<int32_t> data = {1, 2, 3, 4};
|
std::vector<int32_t> data = {1, 2, 3, 4};
|
||||||
std::vector<int32_t> data_modified = {2, 3, 4, 5};
|
std::vector<int32_t> data_modified = {2, 3, 4, 5};
|
||||||
MSTensor tensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
MSTensor tensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
||||||
MSTensor tensor2 = tensor.Clone();
|
MSTensor *tensor2 = tensor.Clone();
|
||||||
|
ASSERT_TRUE(tensor2 != nullptr);
|
||||||
(void)memcpy(tensor.MutableData(), data_modified.data(), data_modified.size() * sizeof(int32_t));
|
(void)memcpy(tensor.MutableData(), data_modified.data(), data_modified.size() * sizeof(int32_t));
|
||||||
auto value = tensor2.Data();
|
auto value = tensor2->Data();
|
||||||
int32_t *p = (int32_t *)value.get();
|
int32_t *p = (int32_t *)value.get();
|
||||||
for (size_t i = 0; i < data.size(); ++i) {
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
ASSERT_EQ(p[i], data[i]);
|
ASSERT_EQ(p[i], data[i]);
|
||||||
}
|
}
|
||||||
|
MSTensor::DestroyTensorPtr(tensor2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiTypes, test_tensor_ref_creator_function_SUCCESS) {
|
TEST_F(TestCxxApiTypes, test_tensor_ref_creator_function_SUCCESS) {
|
||||||
std::vector<int32_t> data = {1, 2, 3, 4};
|
std::vector<int32_t> data = {1, 2, 3, 4};
|
||||||
MSTensor tensor =
|
MSTensor *tensor =
|
||||||
MSTensor::CreateRefTensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
MSTensor::CreateRefTensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
||||||
|
ASSERT_TRUE(tensor != nullptr);
|
||||||
data = {3, 4, 5, 6};
|
data = {3, 4, 5, 6};
|
||||||
auto value = tensor.Data();
|
auto value = tensor->Data();
|
||||||
int32_t *p = (int32_t *)value.get();
|
int32_t *p = (int32_t *)value.get();
|
||||||
for (size_t i = 0; i < data.size(); ++i) {
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
ASSERT_EQ(p[i], data[i]);
|
ASSERT_EQ(p[i], data[i]);
|
||||||
}
|
}
|
||||||
|
MSTensor::DestroyTensorPtr(tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiTypes, test_tensor_creator_function_SUCCESS) {
|
TEST_F(TestCxxApiTypes, test_tensor_creator_function_SUCCESS) {
|
||||||
std::vector<int32_t> data = {1, 2, 3, 4};
|
std::vector<int32_t> data = {1, 2, 3, 4};
|
||||||
MSTensor tensor =
|
MSTensor *tensor =
|
||||||
MSTensor::CreateTensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
MSTensor::CreateTensor("", DataType::kNumberTypeInt32, {4}, data.data(), data.size() * sizeof(int32_t));
|
||||||
|
ASSERT_TRUE(tensor != nullptr);
|
||||||
data = {3, 4, 5, 6};
|
data = {3, 4, 5, 6};
|
||||||
auto value = tensor.Data();
|
auto value = tensor->Data();
|
||||||
int32_t *p = (int32_t *)value.get();
|
int32_t *p = (int32_t *)value.get();
|
||||||
for (size_t i = 0; i < data.size(); ++i) {
|
for (size_t i = 0; i < data.size(); ++i) {
|
||||||
ASSERT_NE(p[i], data[i]);
|
ASSERT_NE(p[i], data[i]);
|
||||||
}
|
}
|
||||||
|
MSTensor::DestroyTensorPtr(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiTypes, test_tensor_string_tensor_SUCCESS) {
|
||||||
|
std::string tensor_name = "tensor_name";
|
||||||
|
std::vector<std::string> origin_strs;
|
||||||
|
origin_strs.emplace_back("qwe");
|
||||||
|
origin_strs.emplace_back("asd");
|
||||||
|
origin_strs.emplace_back("");
|
||||||
|
origin_strs.emplace_back("zxc");
|
||||||
|
auto tensor = MSTensor::StringsToTensor(tensor_name, origin_strs);
|
||||||
|
ASSERT_TRUE(tensor != nullptr);
|
||||||
|
ASSERT_EQ(tensor->Name(), tensor_name);
|
||||||
|
auto new_strs = MSTensor::TensorToStrings(*tensor);
|
||||||
|
ASSERT_EQ(new_strs.size(), origin_strs.size());
|
||||||
|
for (size_t i = 0; i < new_strs.size(); ++i) {
|
||||||
|
ASSERT_EQ(new_strs[i], origin_strs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiTypes, test_tensor_empty_string_tensor_SUCCESS) {
|
||||||
|
std::string tensor_name = "tensor_name";
|
||||||
|
std::vector<std::string> origin_strs;
|
||||||
|
auto tensor = MSTensor::StringsToTensor(tensor_name, origin_strs);
|
||||||
|
ASSERT_TRUE(tensor != nullptr);
|
||||||
|
ASSERT_EQ(tensor->Name(), tensor_name);
|
||||||
|
auto new_strs = MSTensor::TensorToStrings(*tensor);
|
||||||
|
ASSERT_EQ(new_strs.size(), origin_strs.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestCxxApiTypes, test_tensor_string_tensor_invalid_type_FAILED) {
|
||||||
|
MSTensor tensor("", DataType::kNumberTypeInt32, {1}, nullptr, sizeof(int32_t));
|
||||||
|
auto new_strs = MSTensor::TensorToStrings(tensor);
|
||||||
|
ASSERT_TRUE(new_strs.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiTypes, test_buffer_data_ref_and_copy_SUCCESS) {
|
TEST_F(TestCxxApiTypes, test_buffer_data_ref_and_copy_SUCCESS) {
|
||||||
|
|
Loading…
Reference in New Issue