[MSLITE][DEVELOP] modify delegate return value from int to Status, add notes for API

This commit is contained in:
yangruoqi713 2021-08-19 09:46:29 +08:00
parent e187eb15d9
commit 7d07601aeb
18 changed files with 243 additions and 98 deletions

View File

@ -44,27 +44,57 @@ class MS_API Context {
Context();
~Context() = default;
/// \brief Set the number of threads at runtime. This option is only valid for MindSpore Lite.
/// \brief Set the number of threads at runtime. Only valid for Lite.
///
/// \param[in] thread_num the number of threads at runtime.
void SetThreadNum(int32_t thread_num);
/// \brief Get the current thread number setting.
/// \brief Get the current thread number setting. Only valid for Lite.
///
/// \return The current thread number setting.
int32_t GetThreadNum() const;
/// \brief Set the thread affinity to CPU cores.
/// \brief Set the thread affinity to CPU cores. Only valid for Lite.
///
/// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
/// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
void SetThreadAffinity(int mode);
/// \brief Get the thread affinity of CPU cores. Only valid for Lite.
///
/// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
int GetThreadAffinityMode() const;
/// \brief Set the thread affinity to a specified list of CPU cores. Only valid for Lite.
///
/// \note If both core_list and mode are set by SetThreadAffinity, core_list takes effect and mode is
/// ignored.
///
/// \param[in] core_list: a vector of CPU core IDs to bind the runtime threads to.
void SetThreadAffinity(const std::vector<int> &core_list);
/// \brief Get the list of CPU cores that the runtime threads are bound to. Only valid for Lite.
///
/// \return core_list: a vector of the bound CPU core IDs.
std::vector<int32_t> GetThreadAffinityCoreList() const;
/// \brief Set whether to run model inference or training in parallel. Only valid for Lite.
///
/// \param[in] is_parallel: true, run in parallel; false, do not run in parallel.
void SetEnableParallel(bool is_parallel);
/// \brief Get whether model inference or training runs in parallel. Only valid for Lite.
///
/// \return Bool value that indicates whether parallel execution is enabled.
bool GetEnableParallel() const;
/// \brief Set Delegate to access third-party AI framework. Only valid for Lite.
///
/// \param[in] delegate Pointer to the custom delegate.
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
/// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
///
/// \return Pointer to the custom delegate.
std::shared_ptr<Delegate> GetDelegate() const;
/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
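A minimal usage sketch of the Context setters documented above (thread count, core affinity, parallel flag, delegate); the CPUDeviceInfo part assumes the usual MutableDeviceInfo() device list and is illustrative only:

#include <memory>
#include <vector>
#include "include/api/context.h"

std::shared_ptr<mindspore::Context> BuildLiteContext() {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(4);                  // four runtime threads
  context->SetThreadAffinity({0, 1, 2, 3});  // core_list takes effect; any affinity mode is ignored
  context->SetEnableParallel(true);          // run inference in parallel
  // context->SetDelegate(my_delegate);      // optional: plug in a third-party delegate
  auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
  cpu_info->SetEnableFP16(false);
  context->MutableDeviceInfo().push_back(cpu_info);
  return context;
}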
@ -112,19 +142,23 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
/// \brief set provider's name.
///
/// \param[in] provider define the provider's name.
void SetProvider(const std::string &provider);
/// \brief obtain provider's device type.
///
/// \return provider's device type.
std::string GetProviderDevice() const;
/// \brief set provider's device type.
///
/// \param[in] device define the provider's device type, e.g. CPU.
void SetProviderDevice(const std::string &device);
/// \brief set memory allocator.
///
/// \param[in] allocator define the memory allocator which can be defined by user.
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
/// \brief obtain memory allocator.
///
/// \return memory allocator.
@ -147,6 +181,7 @@ class MS_API CPUDeviceInfo : public DeviceInfoContext {
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);
/// \brief Get whether float16 inference is enabled.
///
/// \return Whether float16 inference is enabled.
@ -167,6 +202,7 @@ class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
/// performance), default as 3.
void SetFrequency(int frequency);
/// \brief Get the NPU frequency.
///
/// \return NPU frequency
@ -185,6 +221,7 @@ class MS_API GPUDeviceInfo : public DeviceInfoContext {
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
@ -200,6 +237,7 @@ class MS_API GPUDeviceInfo : public DeviceInfoContext {
///
/// \param[in] is_fp16 Enable float16 inference or not.
void SetEnableFP16(bool is_fp16);
/// \brief Get whether float16 inference is enabled.
///
/// \return Whether float16 inference is enabled.
@ -228,6 +266,7 @@ class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
@ -247,6 +286,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
///
/// \param[in] device_id The device id.
void SetDeviceID(uint32_t device_id);
/// \brief Get the device id.
///
/// \return The device id.
@ -259,6 +299,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
///
/// \param[in] cfg_path AIPP configuration file path.
inline void SetInsertOpConfigPath(const std::string &cfg_path);
/// \brief Get AIPP configuration file path.
///
/// \return AIPP configuration file path.
@ -268,6 +309,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
///
/// \param[in] format Optional "NCHW", "NHWC", etc.
inline void SetInputFormat(const std::string &format);
/// \brief Get format of model inputs.
///
/// \return The format of model inputs.
@ -277,6 +319,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
///
/// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
inline void SetInputShape(const std::string &shape);
/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
@ -287,6 +330,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
/// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
/// shape 4,3,2,1.
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
/// \brief Get shape of model inputs.
///
/// \return The shape of model inputs.
@ -299,6 +343,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
///
/// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
void SetOutputType(enum DataType output_type);
/// \brief Get type of model outputs.
///
/// \return The set type of model outputs.
@ -309,6 +354,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
/// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
/// "allow_mix_precision", "force_fp16" is set as default
inline void SetPrecisionMode(const std::string &precision_mode);
/// \brief Get the precision mode of the model.
///
/// \return The set precision mode.
@ -319,6 +365,7 @@ class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
/// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
/// default.
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
/// \brief Get op select implementation mode.
///
/// \return The set op select implementation mode.
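A hedged sketch of combining the Ascend310DeviceInfo setters documented in this hunk; the configuration path and option values are placeholders, not recommendations:

#include <memory>
#include "include/api/context.h"

std::shared_ptr<mindspore::Context> BuildAscend310Context() {
  auto context = std::make_shared<mindspore::Context>();
  auto ascend_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend_info->SetDeviceID(0);
  ascend_info->SetInsertOpConfigPath("/path/to/aipp.cfg");  // hypothetical AIPP config path
  ascend_info->SetInputFormat("NHWC");
  ascend_info->SetPrecisionMode("allow_fp32_to_fp16");
  ascend_info->SetOpSelectImplMode("high_performance");
  context->MutableDeviceInfo().push_back(ascend_info);
  return context;
}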

View File

@ -22,6 +22,7 @@
#include <memory>
#include "schema/model_generated.h"
#include "include/api/kernel.h"
#include "include/api/status.h"
namespace mindspore {
typedef enum {
@ -100,15 +101,17 @@ class MS_API Delegate {
/// \brief Init delegate.
///
/// \note Init will be called in CreateSession.
virtual int Init() = 0;
/// \note Init will be called in Model::Build.
///
/// \return Status. If the returned Status is kLiteNotSupport, the framework falls back to the MindSpore Lite inner inference.
virtual Status Init() = 0;
/// \brief Build delegate graph for MindSpore Lite model.
///
/// \note Build will be called in LiteSession::CompileGraph.
/// \note Build will be called in Model::Build.
///
/// \param[in] model Define the delegate model to be built.
virtual int Build(DelegateModel *model) = 0;
virtual Status Build(DelegateModel *model) = 0;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DELEGATE_H
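A minimal sketch of a custom delegate built on the Status-returning interface above; only the Init/Build signatures and the kLiteNotSupport fallback contract come from this header, the hardware check is a placeholder:

#include "include/api/delegate.h"
#include "include/api/status.h"

namespace mindspore {
class MyDelegate : public Delegate {
 public:
  Status Init() override {
    if (!HardwareAvailable()) {
      // kLiteNotSupport tells the runtime to fall back to the built-in Lite inference path.
      return kLiteNotSupport;
    }
    return kSuccess;
  }

  Status Build(DelegateModel *model) override {
    // A real delegate would walk the kernels here and Replace() supported segments with a
    // delegate subgraph kernel; returning kSuccess without changes keeps the original kernels.
    return kSuccess;
  }

 private:
  bool HardwareAvailable() const { return true; }  // placeholder capability check
};
}  // namespace mindspore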

View File

@ -44,6 +44,7 @@ class MS_API Model {
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
/// \brief Builds a model so that it can run on a device.
///
/// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
@ -78,6 +79,7 @@ class MS_API Model {
///
/// \return The vector that includes all input tensors.
std::vector<MSTensor> GetInputs();
/// \brief Obtains the input tensor of the model by name.
///
/// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned.
@ -90,15 +92,25 @@ class MS_API Model {
///
/// \return The vector that includes all output tensors.
std::vector<MSTensor> GetOutputs();
/// \brief Obtains names of all output tensors of the model.
///
/// \return A vector that includes names of all output tensors.
inline std::vector<std::string> GetOutputTensorNames();
/// \brief Obtains the output tensor of the model by name.
///
/// \return The output tensor with the given name, if the name is not found, an invalid tensor is returned.
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &tensor_name);
/// \brief Get output MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \note Deprecated, use GetOutputByTensorName instead.
///
/// \return The vector of output MSTensor.
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &node_name);
/// \brief Inference model.
///
@ -112,9 +124,32 @@ class MS_API Model {
bool GetTrainMode() const;
Status Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
Status Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> cbs);
/// \brief Build a model from a model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] data_size Define the number of bytes in the model buffer.
/// \param[in] model_type Define the type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
/// \brief Load and build a model from a model file so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define the type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
@ -140,8 +175,8 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
return GetOutputByTensorName(StringToChar(tensor_name));
}
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &tensor_name) {
return GetOutputsByNodeName(StringToChar(tensor_name));
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H
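A short sketch of the workflow implied by the Build overloads documented above (configure a Context, then Build and query tensors); the model path and device choice are assumptions:

#include <memory>
#include <string>
#include "include/api/context.h"
#include "include/api/model.h"

int LoadLiteModel(const std::string &model_path) {
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  mindspore::Model model;
  // Build from a file path; the buffer overload takes (model_data, data_size, ...) instead.
  auto status = model.Build(model_path, mindspore::ModelType::kMindIR, context);
  if (status != mindspore::kSuccess) {
    return -1;  // failures are reported through Status rather than an int code
  }
  auto inputs = model.GetInputs();    // fill these tensors before running inference
  auto outputs = model.GetOutputs();
  return (inputs.empty() || outputs.empty()) ? -1 : 0;
}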

View File

@ -52,17 +52,12 @@ enum ModelType : uint32_t {
kUnknownType = 0xFFFFFFFF
};
enum QuantizationType : uint32_t {
kNoQuant = 0,
kWeightQuant = 1,
kFullQuant = 2,
kUnknownQuantType = 0xFFFFFFFF
};
enum QuantizationType : uint32_t { kNoQuant = 0, kWeightQuant = 1, kFullQuant = 2, kUnknownQuantType = 0xFFFFFFFF };
enum OptimizationLevel : uint32_t {
kO0 = 0, // Do not change
kO2 = 2, // Cast network to float16, keep batchnorm and loss in float32,
kO3 = 3, // Cast network to float16, including batchnorm
kO0 = 0, // Do not change
kO2 = 2, // Cast network to float16, keep batchnorm and loss in float32,
kO3 = 3, // Cast network to float16, including batchnorm
kAuto = 4, // Choose optimization based on device
kOptimizationType = 0xFFFFFFFF
};
@ -90,6 +85,7 @@ class MS_API MSTensor {
/// \return A pointer of MSTensor.
static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
/// \brief Creates a MSTensor object whose data can be directly accessed by Model; must be used in pairs with
/// DestroyTensorPtr.
///
@ -102,6 +98,7 @@ class MS_API MSTensor {
/// \return A pointer of MSTensor.
static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
/// \brief Creates a MSTensor object whose device data can be directly accessed by Model; must be used in pairs with
/// DestroyTensorPtr.
///
@ -114,6 +111,7 @@ class MS_API MSTensor {
/// \return A pointer of MSTensor.
static inline MSTensor *CreateDevTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
/// \brief Create a string type MSTensor object whose data can be accessed by Model only after being copied; must be
/// used in pairs with DestroyTensorPtr.
///
@ -122,12 +120,14 @@ class MS_API MSTensor {
///
/// \return A pointer of MSTensor.
static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
/// \brief Parse the string type MSTensor object into strings.
///
/// \param[in] tensor A MSTensor object.
///
/// \return A vector container containing several strings.
static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
/// \brief Destroy an object created by Clone, StringsToTensor, CreateRefTensor, CreateDevTensor or CreateTensor. Do
/// not use it to destroy MSTensor from other sources.
///
@ -145,14 +145,17 @@ class MS_API MSTensor {
///
/// \return The name of the MSTensor.
inline std::string Name() const;
/// \brief Obtains the data type of the MSTensor.
///
/// \return The data type of the MSTensor.
enum DataType DataType() const;
/// \brief Obtains the shape of the MSTensor.
///
/// \return The shape of the MSTensor.
const std::vector<int64_t> &Shape() const;
/// \brief Obtains the number of elements of the MSTensor.
///
/// \return The number of elements of the MSTensor.
@ -162,43 +165,95 @@ class MS_API MSTensor {
///
/// \return A shared pointer to the copy of data of the MSTensor.
std::shared_ptr<const void> Data() const;
/// \brief Obtains the pointer to the data of the MSTensor. If the MSTensor is a device tensor, the data cannot be
/// accessed directly on host.
///
/// \return A pointer to the data of the MSTensor.
void *MutableData();
/// \brief Obtains the length of the data of the MSTensor, in bytes.
///
/// \return The length of the data of the MSTensor, in bytes.
size_t DataSize() const;
/// \brief Gets the boolean value that indicates whether the memory of MSTensor is on device.
///
/// \return The boolean value that indicates whether the memory of MSTensor is on device.
bool IsDevice() const;
/// \brief Gets a deep copy of the MSTensor; must be used in pairs with DestroyTensorPtr.
///
/// \return A pointer points to a deep copy of the MSTensor.
MSTensor *Clone() const;
/// \brief Gets the boolean value that indicates whether the MSTensor is valid.
///
/// \return The boolean value that indicates whether the MSTensor is valid.
bool operator==(std::nullptr_t) const;
/// \brief Gets the boolean value that indicates whether the MSTensor is valid.
///
/// \return The boolean value that indicates whether the MSTensor is valid.
bool operator!=(std::nullptr_t) const;
/// \brief Get the boolean value that indicates whether the MSTensor equals tensor.
///
/// \param[in] tensor Another MSTensor to compare against.
///
/// \return The boolean value that indicates whether the MSTensor equals tensor.
bool operator==(const MSTensor &tensor) const;
/// \brief Set the shape for the MSTensor. Only valid for Lite.
///
/// \param[in] shape The shape of the MSTensor, a vector of int64_t.
void SetShape(const std::vector<int64_t> &shape);
/// \brief Set the data type for the MSTensor. Only valid for Lite.
///
/// \param[in] data_type The data type of the MSTensor.
void SetDataType(enum DataType data_type);
/// \brief Set the name for the MSTensor. Only valid for Lite.
///
/// \param[in] name The name of the MSTensor.
void SetTensorName(const std::string &name);
/// \brief Set the Allocator for the MSTensor. Only valid for Lite.
///
/// \param[in] allocator A pointer to the Allocator.
void SetAllocator(std::shared_ptr<Allocator> allocator);
/// \brief Obtain the Allocator of the MSTensor. Only valid for Lite.
///
/// \return A pointer to Allocator.
std::shared_ptr<Allocator> allocator() const;
/// \brief Set the format for the MSTensor. Only valid for Lite.
///
/// \param[in] format The format of the MSTensor.
void SetFormat(mindspore::Format format);
/// \brief Obtain the format of the MSTensor. Only valid for Lite.
///
/// \return The format of the MSTensor.
mindspore::Format format() const;
/// \brief Set the data for the MSTensor. Only valid for Lite.
///
/// \param[in] data A pointer to the data of the MSTensor.
void SetData(void *data);
/// \brief Get the quantization parameters of the MSTensor. Only valid for Lite.
///
/// \return The quantization parameters of the MSTensor.
std::vector<QuantParam> QuantParams() const;
void SetQuantParams(std::vector<QuantParam> quant_args);
/// \brief Set the quantization parameters for the MSTensor. Only valid for Lite.
///
/// \param[in] quant_params The quantization parameters of the MSTensor.
void SetQuantParams(std::vector<QuantParam> quant_params);
const std::shared_ptr<Impl> impl() const { return impl_; }
private:
@ -276,12 +331,13 @@ using Key = struct Key {
Key() : len(0) {}
explicit Key(const char *dec_key, size_t key_len);
};
constexpr char kDecModeAesGcm[] = "AES-GCM";
/// \brief MSCallBackParam defines the input arguments for the callback function.
struct MSCallBackParam {
std::string node_name_; /**< node name argument */
std::string node_type_; /**< node type argument */
std::string node_name; /**< node name argument */
std::string node_type; /**< node type argument */
};
/// \brief KernelCallBack defines the function pointer for the callback.

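A small sketch of the paired CreateTensor/DestroyTensorPtr usage documented above; the tensor name, shape, and DataType value are illustrative assumptions:

#include <vector>
#include "include/api/types.h"

void CreateAndReleaseTensor() {
  std::vector<float> data(1 * 3 * 224 * 224, 0.0f);
  // CreateTensor copies the buffer; CreateRefTensor would reference it in place instead.
  auto *tensor = mindspore::MSTensor::CreateTensor("input_0", mindspore::DataType::kNumberTypeFloat32,
                                                   {1, 3, 224, 224}, data.data(), data.size() * sizeof(float));
  if (tensor == nullptr) {
    return;
  }
  // Objects returned by CreateTensor must be released with DestroyTensorPtr, as the header notes.
  mindspore::MSTensor::DestroyTensorPtr(tensor);
}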
View File

@ -226,13 +226,13 @@ build_lite() {
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
cd ${BASEPATH}/../
if [[ "${local_lite_platform}" == "x86_64" ]]; then
sh ${compile_nnie_script} -I x86_64 -b nnie_master -j $THREAD_NUM
sh ${compile_nnie_script} -I x86_64 -b nnie_master_tmp -j $THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "compile x86_64 for nnie failed."
exit 1
fi
elif [[ "${local_lite_platform}" == "arm32" ]]; then
sh ${compile_nnie_script} -I arm32 -b nnie_master -j $THREAD_NUM
sh ${compile_nnie_script} -I arm32 -b nnie_master_tmp -j $THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "compile arm32 for nnie failed."
exit 1

View File

@ -162,8 +162,8 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
MSCallBackParam mscall_param;
mscall_param.node_name_ = call_param.node_name;
mscall_param.node_type_ = call_param.node_type;
mscall_param.node_name = call_param.node_name;
mscall_param.node_type = call_param.node_type;
return before(inputs, outputs, mscall_param);
};
@ -173,8 +173,8 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
MSCallBackParam mscall_param;
mscall_param.node_name_ = call_param.node_name;
mscall_param.node_type_ = call_param.node_type;
mscall_param.node_name = call_param.node_name;
mscall_param.node_type = call_param.node_type;
return after(inputs, outputs, mscall_param);
};
auto ret = session_->RunGraph(before_call_back, after_call_back);

View File

@ -72,24 +72,24 @@ NPUDelegate::~NPUDelegate() {
}
}
int NPUDelegate::Init() {
Status NPUDelegate::Init() {
npu_manager_ = new (std::nothrow) NPUManager(frequency_);
if (npu_manager_ == nullptr) {
MS_LOG(ERROR) << "New npu manager failed.";
return RET_ERROR;
return mindspore::kLiteNullptr;
}
if (!npu_manager_->IsSupportNPU()) {
MS_LOG(DEBUG) << "Checking that npu is unsupported.";
free(npu_manager_);
npu_manager_ = nullptr;
return RET_NOT_SUPPORT;
return mindspore::kLiteNotSupport;
}
pass_manager_ = new (std::nothrow) NPUPassManager();
if (pass_manager_ == nullptr) {
free(npu_manager_);
npu_manager_ = nullptr;
MS_LOG(ERROR) << "New npu pass manager failed.";
return RET_ERROR;
return mindspore::kLiteNullptr;
}
auto transform_pass = new (std::nothrow) NPUTransformPass();
pass_manager_->AddPass(transform_pass);
@ -157,10 +157,10 @@ int NPUDelegate::Init() {
{schema::PrimitiveType_Transpose, GetNPUOp<TransposeNPUOp>},
{schema::PrimitiveType_Unsqueeze, GetNPUOp<UnsqueezeNPUOp>},
};
return RET_OK;
return mindspore::kSuccess;
}
int NPUDelegate::Build(DelegateModel *model) {
Status NPUDelegate::Build(DelegateModel *model) {
KernelIter from, end;
std::vector<NPUOp *> npu_ops;
int graph_index = 0;
@ -179,7 +179,7 @@ int NPUDelegate::Build(DelegateModel *model) {
auto npu_graph_kernel = CreateNPUGraph(npu_ops, model, from, end);
if (npu_graph_kernel == nullptr) {
MS_LOG(ERROR) << "Create NPU Graph failed.";
return RET_ERROR;
return mindspore::kLiteNullptr;
}
npu_graph_kernel->set_name("NpuGraph" + std::to_string(graph_index++));
iter = model->Replace(from, end + 1, npu_graph_kernel);
@ -191,7 +191,7 @@ int NPUDelegate::Build(DelegateModel *model) {
auto npu_graph_kernel = CreateNPUGraph(npu_ops, model, from, end);
if (npu_graph_kernel == nullptr) {
MS_LOG(ERROR) << "Create NPU Graph failed.";
return RET_ERROR;
return mindspore::kLiteNullptr;
}
npu_graph_kernel->set_name("NpuGraph" + std::to_string(graph_index++));
model->Replace(from, end + 1, npu_graph_kernel);
@ -200,9 +200,9 @@ int NPUDelegate::Build(DelegateModel *model) {
auto ret = npu_manager_->LoadOMModel();
if (ret != RET_OK) {
MS_LOG(ERROR) << "NPU client load model failed.";
return RET_ERROR;
return mindspore::kLiteError;
}
return RET_OK;
return mindspore::kSuccess;
}
NPUOp *NPUDelegate::GetOP(kernel::Kernel *kernel, const schema::Primitive *primitive) {

View File

@ -32,9 +32,9 @@ class NPUDelegate : public Delegate {
~NPUDelegate() override;
int Init() override;
Status Init() override;
int Build(DelegateModel *model) override;
Status Build(DelegateModel *model) override;
protected:
NPUOp *GetOP(kernel::Kernel *kernel, const schema::Primitive *primitive);

View File

@ -48,9 +48,9 @@ bool IsHardwareSupport() {
return true;
}
int TensorRTDelegate::Init() {
Status TensorRTDelegate::Init() {
if (!IsHardwareSupport()) {
return RET_NOT_SUPPORT;
return mindspore::kLiteNotSupport;
}
std::vector<std::shared_ptr<DeviceInfoContext>> device_list = context_->MutableDeviceInfo();
auto iter = std::find_if(device_list.begin(), device_list.end(), [](std::shared_ptr<DeviceInfoContext> device) {
@ -58,12 +58,12 @@ int TensorRTDelegate::Init() {
});
if (iter == device_list.end()) {
MS_LOG(ERROR) << "no gpu device info found for TensorRT.";
return RET_ERROR;
return mindspore::kLiteError;
}
auto gpu_info = (*iter)->Cast<GPUDeviceInfo>();
if (gpu_info == nullptr) {
MS_LOG(ERROR) << "no gpu device info found for TensorRT.";
return RET_ERROR;
return mindspore::kLiteError;
}
device_info_ = gpu_info;
op_func_lists_.clear();
@ -93,10 +93,10 @@ int TensorRTDelegate::Init() {
{schema::PrimitiveType_Flatten, GetTensorRTOp<ShuffleTensorRT>},
{schema::PrimitiveType_Sqrt, GetTensorRTOp<UnaryTensorRT>},
};
return RET_OK;
return mindspore::kSuccess;
}
int TensorRTDelegate::Build(DelegateModel *model) {
Status TensorRTDelegate::Build(DelegateModel *model) {
KernelIter from, end;
std::vector<TensorRTOp *> tensorrt_ops;
int graph_index = 0;
@ -115,7 +115,7 @@ int TensorRTDelegate::Build(DelegateModel *model) {
auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end);
if (tensorrt_subgraph == nullptr) {
MS_LOG(ERROR) << "Create TensorRT Graph failed.";
return RET_ERROR;
return mindspore::kLiteNullptr;
}
tensorrt_subgraph->set_name("TensorRtGraph" + std::to_string(graph_index++));
iter = model->Replace(from, end + 1, tensorrt_subgraph);
@ -127,13 +127,13 @@ int TensorRTDelegate::Build(DelegateModel *model) {
auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end);
if (tensorrt_subgraph == nullptr) {
MS_LOG(DEBUG) << "Create TensorRT Graph failed.";
return RET_ERROR;
return mindspore::kLiteNullptr;
}
tensorrt_subgraph->set_name("TensorRtGraph" + std::to_string(graph_index++));
model->Replace(from, end + 1, tensorrt_subgraph);
tensorrt_ops.clear();
}
return RET_OK;
return mindspore::kSuccess;
}
TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive) {

View File

@ -37,9 +37,9 @@ class TensorRTDelegate : public Delegate {
~TensorRTDelegate() override = default;
int Init() override;
Status Init() override;
int Build(DelegateModel *model) override;
Status Build(DelegateModel *model) override;
private:
TensorRTOp *FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive);

View File

@ -49,9 +49,7 @@ struct KernelKey {
int type = 0;
std::string kernel_arch;
std::string provider{kBuiltin};
#ifndef DELEGATE_CLIP
std::shared_ptr<Delegate> delegate = nullptr;
#endif
bool operator<(const KernelKey &dst) const {
if (provider != dst.provider) {
return provider < dst.provider;

View File

@ -199,7 +199,7 @@ int LiteKernelUtil::SetInput(const LiteKernel &kernelMod, const std::vector<lite
#ifndef CONTROLFLOW_TENSORLIST_CLIP
bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) {
#ifndef DELEGATE_CLIP
if (kernel->desc().delegate != nullptr) {
if (kernel->desc().arch == kernel::kDelegate) {
return false;
}
#endif

View File

@ -78,7 +78,7 @@ void LiteOpActor::ReplaceNodeInTensor(kernel::LiteKernel *kernel, Tensor *old_te
int ref_count = 0;
#ifndef DELEGATE_CLIP
/* set op input for calculate */
if (kernel->desc().delegate != nullptr) {
if (kernel->desc().arch == kernel::kDelegate) {
ref_count++;
} else {
#endif
@ -202,7 +202,7 @@ int LiteOpActor::CompileArrowThroughOutputKernels() {
#ifndef CONTROLFLOW_TENSORLIST_CLIP
int LiteOpActor::CompileArrowThroughPartialCall() {
#ifndef DELEGATE_CLIP
if (kernel_->desc().delegate != nullptr) {
if (kernel_->desc().arch == kernel::kDelegate) {
MS_LOG(INFO) << "kernel is delegate subgraph kernel.";
return RET_OK;
}

View File

@ -420,7 +420,7 @@ void LiteSession::IsolateOutputTensor() {
}
}
#ifndef DELEGATE_CLIP
if (subgraph->desc().delegate != nullptr) {
if (subgraph->desc().arch == kernel::kDelegate) {
continue;
}
#endif
@ -581,7 +581,7 @@ int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) {
for (auto kernel : this->kernels_) {
kernel->FindInoutKernels(this->kernels_);
#ifndef DELEGATE_CLIP
if (kernel->desc().delegate != nullptr) {
if (kernel->desc().arch == kernel::kDelegate) {
all_kernels.push_back(kernel);
} else {
#endif
@ -604,7 +604,7 @@ int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) {
// init init_ref_count for subgraphs and kernels
for (auto *kernel : this->kernels_) {
#ifndef DELEGATE_CLIP
if (kernel->desc().delegate != nullptr) {
if (kernel->desc().arch == kernel::kDelegate) {
continue;
}
#endif
@ -699,11 +699,13 @@ int LiteSession::Init(InnerContext *context) {
#ifndef DELEGATE_CLIP
if (delegate_ != nullptr) {
auto delegate_ret = delegate_->Init();
if (delegate_ret == RET_NOT_SUPPORT) {
if (delegate_ret == mindspore::kLiteNotSupport) {
MS_LOG(DEBUG) << "Delegate is unsupported";
delegate_.reset();
delegate_ = nullptr;
}
if (delegate_ret == RET_ERROR) {
} else if (delegate_ret == mindspore::kSuccess) {
MS_LOG(INFO) << "Delegate init successfully";
} else {
MS_LOG(ERROR) << "Delegate init failed";
return RET_ERROR;
}
@ -855,7 +857,7 @@ int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels)
}
auto ret = RET_OK;
#ifndef DELEGATE_CLIP
if (kernel->desc().delegate != nullptr) {
if (kernel->desc().arch == kernel::kDelegate) {
ret = kernel->ReSize();
} else {
#endif

View File

@ -80,7 +80,7 @@ int Scheduler::InitKernels(std::vector<kernel::LiteKernel *> dst_kernels) {
for (auto kernel : dst_kernels) {
#ifndef DELEGATE_CLIP
// delegate graph kernel
if (kernel->desc().delegate != nullptr) {
if (kernel->desc().arch == kernel::kDelegate) {
continue;
}
#endif
@ -230,10 +230,10 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
return RET_NULL_PTR;
}
auto ret = delegate_->Build(model);
if (ret != RET_OK) {
if (ret != mindspore::kSuccess) {
delete model;
MS_LOG(ERROR) << "Delegate prepare kernels failed.";
return ret;
return RET_ERROR;
}
auto src_kernels = *dst_kernels;
@ -268,7 +268,7 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
break;
}
}
kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, schema::PrimitiveType_NONE, "", "", delegate_};
kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, schema::PrimitiveType_NONE, "", ""};
lite_kernel->set_desc(delegate_desc);
dst_kernels->push_back(lite_kernel);
}
@ -1248,7 +1248,7 @@ kernel::LiteKernel *FindAllSubGraphKernels(const std::vector<kernel::LiteKernel
MS_ASSERT(GetKernelSubGraphType(cur_kernel, context) != kernel::kApuSubGraph);
// already a subgraph or a delegate
#ifndef DELEGATE_CLIP
if (cur_kernel->desc().delegate != nullptr) {
if (cur_kernel->desc().arch == kernel::kDelegate) {
--(*cur_index);
break;
}
@ -1278,7 +1278,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
// Not support APU now
MS_ASSERT(GetKernelSubGraphType(cur_kernel, *context_) != kernel::kApuSubGraph);
#ifndef DELEGATE_CLIP
if (cur_kernel->desc().delegate != nullptr) {
if (cur_kernel->desc().arch == kernel::kDelegate) {
dst_kernel->emplace_back(cur_kernel);
continue;
}
@ -1297,8 +1297,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> src_kernel,
}
for (auto *subgraph : *dst_kernel) {
#ifndef DELEGATE_CLIP
auto subgraph_delegate = subgraph->desc().delegate;
if (subgraph_delegate == nullptr) {
if (subgraph->desc().arch != kernel::kDelegate) {
#endif
auto ret = subgraph->Init();
if (ret != RET_OK) {

View File

@ -340,6 +340,11 @@ void *Tensor::MutableData() {
return this->data_;
}
void *Tensor::data() {
Prepare();
return this->data_;
}
void Tensor::IncRefCount() {
ref_count_++;
if (allocator_ != nullptr) {

View File

@ -120,7 +120,7 @@ class Tensor : public mindspore::tensor::MSTensor {
void *ReallocData();
void *data() override { return this->data_; }
void *data() override;
virtual void *data_c() const { return data_; }

View File

@ -512,11 +512,11 @@ int BenchmarkUnifiedApi::InitTimeProfilingCallbackParameter() {
if (before_outputs.empty()) {
MS_LOG(INFO) << "The num of beforeOutputs is empty";
}
if (op_times_by_type_.find(call_param.node_type_) == op_times_by_type_.end()) {
op_times_by_type_.insert(std::make_pair(call_param.node_type_, std::make_pair(0, 0.0f)));
if (op_times_by_type_.find(call_param.node_type) == op_times_by_type_.end()) {
op_times_by_type_.insert(std::make_pair(call_param.node_type, std::make_pair(0, 0.0f)));
}
if (op_times_by_name_.find(call_param.node_name_) == op_times_by_name_.end()) {
op_times_by_name_.insert(std::make_pair(call_param.node_name_, std::make_pair(0, 0.0f)));
if (op_times_by_name_.find(call_param.node_name) == op_times_by_name_.end()) {
op_times_by_name_.insert(std::make_pair(call_param.node_name, std::make_pair(0, 0.0f)));
}
op_call_times_total_++;
@ -542,10 +542,10 @@ int BenchmarkUnifiedApi::InitTimeProfilingCallbackParameter() {
cost = static_cast<float>(gpu_param.execute_time);
}
op_cost_total_ += cost;
op_times_by_type_[call_param.node_type_].first++;
op_times_by_type_[call_param.node_type_].second += cost;
op_times_by_name_[call_param.node_name_].first++;
op_times_by_name_[call_param.node_name_].second += cost;
op_times_by_type_[call_param.node_type].first++;
op_times_by_type_[call_param.node_type].second += cost;
op_times_by_name_[call_param.node_name].first++;
op_times_by_name_[call_param.node_name].second += cost;
return true;
};
return RET_OK;
@ -604,11 +604,11 @@ int BenchmarkUnifiedApi::InitPerfProfilingCallbackParameter() {
if (before_outputs.empty()) {
MS_LOG(INFO) << "The num of beforeOutputs is empty";
}
if (op_perf_by_type_.find(call_param.node_type_) == op_perf_by_type_.end()) {
op_perf_by_type_.insert(std::make_pair(call_param.node_type_, std::make_pair(0, zero)));
if (op_perf_by_type_.find(call_param.node_type) == op_perf_by_type_.end()) {
op_perf_by_type_.insert(std::make_pair(call_param.node_type, std::make_pair(0, zero)));
}
if (op_perf_by_name_.find(call_param.node_name_) == op_perf_by_name_.end()) {
op_perf_by_name_.insert(std::make_pair(call_param.node_name_, std::make_pair(0, zero)));
if (op_perf_by_name_.find(call_param.node_name) == op_perf_by_name_.end()) {
op_perf_by_name_.insert(std::make_pair(call_param.node_name, std::make_pair(0, zero)));
}
op_call_times_total_++;
@ -637,12 +637,12 @@ int BenchmarkUnifiedApi::InitPerfProfilingCallbackParameter() {
float cost2 = static_cast<float>(res.values[1].value);
op_cost_total_ += cost1;
op_cost2_total_ += cost2;
op_perf_by_type_[call_param.node_type_].first++;
op_perf_by_type_[call_param.node_type_].second.value[0] += cost1;
op_perf_by_type_[call_param.node_type_].second.value[1] += cost2;
op_perf_by_name_[call_param.node_name_].first++;
op_perf_by_name_[call_param.node_name_].second.value[0] += cost1;
op_perf_by_name_[call_param.node_name_].second.value[1] += cost2;
op_perf_by_type_[call_param.node_type].first++;
op_perf_by_type_[call_param.node_type].second.value[0] += cost1;
op_perf_by_type_[call_param.node_type].second.value[1] += cost2;
op_perf_by_name_[call_param.node_name].first++;
op_perf_by_name_[call_param.node_name].second.value[0] += cost1;
op_perf_by_name_[call_param.node_name].second.value[1] += cost2;
return true;
};
#endif
@ -728,12 +728,12 @@ int BenchmarkUnifiedApi::InitPrintTensorDataCallbackParameter() {
ms_after_call_back_ = [&](const std::vector<mindspore::MSTensor> &after_inputs,
const std::vector<mindspore::MSTensor> &after_outputs, const MSCallBackParam &call_param) {
std::cout << "================================================================" << std::endl;
std::cout << call_param.node_name_ << " inputs : " << std::endl;
std::cout << call_param.node_name << " inputs : " << std::endl;
for (auto ms_tensor : after_inputs) {
std::cout << DumpMSTensor(&ms_tensor) << std::endl;
}
std::cout << "----------------------------------------------------------------" << std::endl;
std::cout << call_param.node_name_ << " outputs : " << std::endl;
std::cout << call_param.node_name << " outputs : " << std::endl;
for (auto ms_tensor : after_outputs) {
std::cout << DumpMSTensor(&ms_tensor) << std::endl;
}
@ -750,11 +750,11 @@ int BenchmarkUnifiedApi::InitDumpTensorDataCallbackParameter() {
auto dump_mode = dump_cfg_json_[dump::kSettings][dump::kMode].get<int>();
auto input_output_mode = dump_cfg_json_[dump::kSettings][dump::kInputOutput].get<int>();
auto kernels = dump_cfg_json_[dump::kSettings][dump::kKernels].get<std::vector<std::string>>();
if (dump_mode == 0 || std::find(kernels.begin(), kernels.end(), call_param.node_name_) != kernels.end()) {
if (dump_mode == 0 || std::find(kernels.begin(), kernels.end(), call_param.node_name) != kernels.end()) {
if (input_output_mode == 0 || input_output_mode == 1) {
for (size_t i = 0; i < before_inputs.size(); i++) {
auto ms_tensor = before_inputs.at(i);
auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name_, "input", i);
auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name, "input", i);
auto abs_file_path = dump_file_output_dir_ + "/" + file_name;
if (WriteToBin(abs_file_path, ms_tensor.MutableData(), ms_tensor.DataSize()) != RET_OK) { // save to file
MS_LOG(ERROR) << "write tensor data to file failed.";
@ -773,11 +773,11 @@ int BenchmarkUnifiedApi::InitDumpTensorDataCallbackParameter() {
auto input_output_mode = dump_cfg_json_[dump::kSettings][dump::kInputOutput].get<int>();
auto kernels = dump_cfg_json_[dump::kSettings][dump::kKernels].get<std::vector<std::string>>();
if (dump_mode == kDumpInputsAndOutputs ||
std::find(kernels.begin(), kernels.end(), call_param.node_name_) != kernels.end()) {
std::find(kernels.begin(), kernels.end(), call_param.node_name) != kernels.end()) {
if (input_output_mode == kDumpInputsAndOutputs || input_output_mode == kDumpOutputs) {
for (size_t i = 0; i < after_outputs.size(); i++) {
auto ms_tensor = after_outputs.at(i);
auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name_, "output", i);
auto file_name = GenerateOutputFileName(&ms_tensor, call_param.node_name, "output", i);
auto abs_file_path = dump_file_output_dir_ + "/" + file_name;
if (WriteToBin(abs_file_path, ms_tensor.MutableData(), ms_tensor.DataSize()) != RET_OK) { // save to file
MS_LOG(ERROR) << "write tensor data to file failed.";
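A brief sketch of a profiling-style callback written against the renamed MSCallBackParam fields used throughout the benchmark code above; the (inputs, outputs, call_param) signature follows the MSKernelCallBack usage in model_impl.cc, and how the callback is wired into a run is assumed rather than shown in this diff:

#include <iostream>
#include <vector>
#include "include/api/types.h"

// A callback with the (inputs, outputs, MSCallBackParam) shape used by model_impl.cc above.
bool PrintNodeBefore(const std::vector<mindspore::MSTensor> &inputs,
                     const std::vector<mindspore::MSTensor> &outputs,
                     const mindspore::MSCallBackParam &call_param) {
  // After this commit the fields carry no trailing underscore: node_name / node_type.
  std::cout << "running " << call_param.node_name << " (" << call_param.node_type << ")" << std::endl;
  return true;  // return true so that execution continues
}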