micro support train

This commit is contained in:
jianghui58 2022-08-04 19:35:27 +08:00
parent 2b35007b0c
commit bc531acbb8
45 changed files with 1349 additions and 475 deletions

View File

@ -102,7 +102,7 @@ class MS_API Model {
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::string &cropto_lib_path);
/// \brief Builds a model
/// \brief Build a model
///
/// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
/// from Graph, for example, model.Build(GraphCell(graph), context).
@ -124,7 +124,7 @@ class MS_API Model {
Status Build(GraphCell graph, Node *optimizer, std::vector<Expr *> inputs,
const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg);
/// \brief Builds a Transfer Learning model where the backbone weights are fixed and the head weights are trainable
/// \brief Build a Transfer Learning model where the backbone weights are fixed and the head weights are trainable
///
/// \param[in] backbone The static, non-learnable part of the graph
/// \param[in] head The trainable part of the graph
@ -135,7 +135,7 @@ class MS_API Model {
Status BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
/// \brief Resizes the shapes of inputs.
/// \brief Resize the shapes of inputs.
///
/// \param[in] inputs A vector that includes all input tensors in order.
/// \param[in] dims Defines the new shapes of inputs, should be consistent with inputs.
@ -170,10 +170,10 @@ class MS_API Model {
/// \return Status.
Status Predict(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
/// \brief Train model by step.
/// \brief Run model by step.
///
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
/// \param[in] before CallBack before RunStep.
/// \param[in] after CallBack after RunStep.
///
/// \return Status.
Status RunStep(const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
@ -226,36 +226,36 @@ class MS_API Model {
/// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned.
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
/// \brief Obtains all gradient tensors of the model.
/// \brief Obtain all gradient tensors of the model.
///
/// \return The vector that includes all gradient tensors.
std::vector<MSTensor> GetGradients() const;
/// \brief update gradient tensors of the model.
/// \brief Update gradient tensors of the model.
///
/// \param[in] gradients A vector new gradients.
///
/// \return Status of operation
Status ApplyGradients(const std::vector<MSTensor> &gradients);
/// \brief Obtains all weights tensors of the model.
/// \brief Obtain all weights tensors of the model.
///
/// \return The vector that includes all gradient tensors.
std::vector<MSTensor> GetFeatureMaps() const;
/// \brief update weights tensors of the model.
/// \brief Update weights tensors of the model.
///
/// \param[in] new_weights A vector new weights.
///
/// \return Status of operation
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
/// \brief Obtains optimizer params tensors of the model.
/// \brief Obtain optimizer params tensors of the model.
///
/// \return The vector that includes all params tensors.
std::vector<MSTensor> GetOptimizerParams() const;
/// \brief update the optimizer parameters.
/// \brief Update the optimizer parameters.
///
/// \param[in] params A vector new optimizer params.
///
@ -271,14 +271,14 @@ class MS_API Model {
/// \return Status of operation.
Status SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f);
/// \brief Sets the Learning Rate of the training.
/// \brief Set the Learning Rate of the training.
///
/// \param[in] learning_rate to set.
///
/// \return Status of operation.
Status SetLearningRate(float learning_rate);
/// \brief Gets the Learning Rate of the optimizer.
/// \brief Get the Learning Rate of the optimizer.
///
/// \return Learning rate. 0.0 if no optimizer was found.
float GetLearningRate();

View File

@ -85,7 +85,7 @@ MS_API MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t
MS_API MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSModelType model_type,
const MSContextHandle model_context);
/// \brief Resizes the shapes of inputs.
/// \brief Resize the shapes of inputs.
///
/// \param[in] model Model object handle.
/// \param[in] inputs The array that includes all input tensor handles.
@ -108,21 +108,46 @@ MS_API MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inp
MS_API MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, MSTensorHandleArray *outputs,
const MSKernelCallBackC before, const MSKernelCallBackC after);
/// \brief Obtains all input tensor handles of the model.
/// \brief Run model by step. Only valid for Iot.
///
/// \param[in] model Model object handle.
/// \param[in] before CallBack before RunStep.
/// \param[in] after CallBack after RunStep.
///
/// \return MSStatus.
MS_API MSStatus MSModelRunStep(MSModelHandle model, const MSKernelCallBackC before, const MSKernelCallBackC after);
/// \brief Set the model running mode. Only valid for Iot.
///
/// \param[in] model Model object handle.
/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode.
///
/// \return Status of operation.
MS_API MSStatus MSModelSetTrainMode(const MSModelHandle model, bool train);
/// \brief Export the weights of model to the binary file. Only valid for Iot.
///
/// \param[in] model Model object handle.
/// \param[in] export_path Define the export weight file path.
///
/// \return Status of operation.
MS_API MSStatus MSModelExportWeight(const MSModelHandle model, const char *export_path);
/// \brief Obtain all input tensor handles of the model.
///
/// \param[in] model Model object handle.
///
/// \return The array that includes all input tensor handles.
MS_API MSTensorHandleArray MSModelGetInputs(const MSModelHandle model);
/// \brief Obtains all output tensor handles of the model.
/// \brief Obtain all output tensor handles of the model.
///
/// \param[in] model Model object handle.
///
/// \return The array that includes all output tensor handles.
MS_API MSTensorHandleArray MSModelGetOutputs(const MSModelHandle model);
/// \brief Obtains the input tensor handle of the model by name.
/// \brief Obtain the input tensor handle of the model by name.
///
/// \param[in] model Model object handle.
/// \param[in] tensor_name The name of tensor.
@ -130,7 +155,7 @@ MS_API MSTensorHandleArray MSModelGetOutputs(const MSModelHandle model);
/// \return The input tensor handle with the given name, if the name is not found, an NULL is returned.
MS_API MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char *tensor_name);
/// \brief Obtains the output tensor handle of the model by name.
/// \brief Obtain the output tensor handle of the model by name.
///
/// \param[in] model Model object handle.
/// \param[in] tensor_name The name of tensor.

View File

@ -19,7 +19,6 @@
#include "nnacl/op_base.h"
#include "schema/model_generated.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
namespace mindspore {
namespace lite {

View File

@ -17,9 +17,11 @@
#ifndef MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_
#define MINDSPORE_LITE_SRC_COMMON_PRIM_UTIL_H_
#include "src/common/version_manager.h"
namespace mindspore {
namespace lite {
int GetPrimitiveType(const void *primitive, int schema_version);
int GetPrimitiveType(const void *primitive, int schema_version = SCHEMA_CUR);
const char *GetPrimitiveTypeName(const void *primitive, int schema_version);
const char *PrimitiveCurVersionTypeName(int type);
int GenPrimVersionKey(int primitive_type, int schema_version);

View File

@ -383,6 +383,21 @@ MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, M
return static_cast<MSStatus>(ret.StatusCode());
}
MSStatus MSModelRunStep(MSModelHandle model, const MSKernelCallBackC before, const MSKernelCallBackC after) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMSStatusLiteNotSupport;
}
MSStatus MSModelSetTrainMode(const MSModelHandle model, bool train) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMSStatusLiteNotSupport;
}
MSStatus MSModelExportWeight(const MSModelHandle model, const char *export_path) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMSStatusLiteNotSupport;
}
MSTensorHandleArray MSModelGetInputs(const MSModelHandle model) {
if (model == nullptr) {
MS_LOG(ERROR) << "param is nullptr.";

View File

@ -315,7 +315,7 @@ int TrainSession::AllocTensors(const std::vector<kernel::KernelExec *> &kernels)
for (auto tensor : kernel->out_tensors()) {
auto it = offset_map.find(tensor);
if (it != offset_map.end()) {
tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<char *>(tensors_data_) + it->second));
tensor->set_data(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensors_data_) + it->second));
}
}
}
@ -764,7 +764,7 @@ void TrainSession::CompileTrainKernels() {
}
}
std::unordered_map<kernel::KernelExec *, int> map;
while (queue.size()) {
while (!queue.empty()) {
// pop first element
auto k = queue.front();
train_kernels_.push_back(k);

View File

@ -245,13 +245,11 @@ int NetTrain::MarkPerformance() {
uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
uint64_t time_max = 0;
uint64_t time_avg = 0;
std::vector<MSTensor> outputs;
for (int i = 0; i < flags_->epochs_; i++) {
auto start = GetTimeUs();
auto status = flags_->time_profiling_
? ms_model_.Predict(ms_inputs_for_api_, &outputs, before_call_back_, after_call_back_)
: ms_model_.Predict(ms_inputs_for_api_, &outputs);
auto status =
flags_->time_profiling_ ? ms_model_.RunStep(before_call_back_, after_call_back_) : ms_model_.RunStep();
if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Inference error " << status;
std::cerr << "Inference error " << status;
@ -299,8 +297,7 @@ int NetTrain::MarkAccuracy(bool enforce_accuracy) {
return RET_ERROR;
}
}
std::vector<MSTensor> outputs;
auto status = ms_model_.Predict(ms_inputs_for_api_, &outputs);
auto status = ms_model_.RunStep();
if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Inference error " << status;
std::cerr << "Inference error " << status << std::endl;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,7 +22,7 @@
namespace mindspore {
namespace lite {
int EraseBlankSpaceAndLineBreak(std::string *input_string) {
bool EraseBlankSpaceAndLineBreak(std::string *input_string) {
if (input_string == nullptr) {
MS_LOG(ERROR) << "input_string is nullptr";
return false;
@ -33,7 +33,7 @@ int EraseBlankSpaceAndLineBreak(std::string *input_string) {
return true;
}
int EraseQuotes(std::string *input_string) {
bool EraseQuotes(std::string *input_string) {
if (input_string == nullptr) {
MS_LOG(ERROR) << "input_string is nullptr";
return false;
@ -49,6 +49,19 @@ int EraseQuotes(std::string *input_string) {
return true;
}
bool FindAndReplaceAll(std::string *input_str, const std::string &search, const std::string &replace) {
if (input_str == nullptr) {
MS_LOG(ERROR) << "input_str is nullptr";
return false;
}
auto pos = input_str->find(search);
while (pos != std::string::npos) {
input_str->replace(pos, search.size(), replace);
pos = input_str->find(search, pos + replace.size());
}
return true;
}
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter) {
if (raw_str.empty()) {
MS_LOG(ERROR) << "input string is empty.";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,9 +26,11 @@
namespace mindspore {
namespace lite {
int EraseBlankSpaceAndLineBreak(std::string *input_string);
bool EraseBlankSpaceAndLineBreak(std::string *input_string);
int EraseQuotes(std::string *input_string);
bool EraseQuotes(std::string *input_string);
bool FindAndReplaceAll(std::string *input_str, const std::string &search, const std::string &replace);
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);

View File

@ -62,7 +62,9 @@ STATUS IsolateDropoutNode(schema::MetaGraphT *graphT, size_t nodeIdx) {
auto matchedTensor =
std::find_if(gOutTensorIdx.begin(), gOutTensorIdx.end(),
[&outDataTensorIdx](const unsigned int &idx) { return (idx == outDataTensorIdx); });
if (matchedTensor != gOutTensorIdx.end()) {
*matchedTensor = inDataTensorIdx;
}
// find poseNode
auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
for (auto postNodeIdx : postNodeIdxes) {
@ -71,9 +73,11 @@ STATUS IsolateDropoutNode(schema::MetaGraphT *graphT, size_t nodeIdx) {
MS_ASSERT(postNode != nullptr);
auto iter = std::find_if(postNode->inputIndex.begin(), postNode->inputIndex.end(),
[&outDataTensorIdx](const unsigned int &idx) { return (idx == outDataTensorIdx); });
if (iter != postNode->inputIndex.end()) {
*iter = inDataTensorIdx;
}
}
}
// now all node's outputTensors are useless
// remove all node's outputTensors

View File

@ -7,9 +7,15 @@ set(CODER_SRC
${MICRO_DIR}/coder/train.cc
${MICRO_DIR}/coder/utils/coder_utils.cc
${MICRO_DIR}/coder/utils/dir_utils.cc
${MICRO_DIR}/coder/utils/train_utils.cc
${MICRO_DIR}/coder/utils/type_cast.cc
)
set(CODER_SRC ${CODER_SRC}
${MICRO_DIR}/coder/train/train_session.cc
${MICRO_DIR}/coder/train/train_generator.cc
)
set(CODER_ALLOCATOR_SRC
${MICRO_DIR}/coder/allocator/allocator.cc
${MICRO_DIR}/coder/allocator/memory_manager.cc
@ -18,7 +24,6 @@ set(CODER_ALLOCATOR_SRC
set(CODER_GENERATOR_SRC
${MICRO_DIR}/coder/generator/generator.cc
${MICRO_DIR}/coder/generator/inference/inference_generator.cc
${MICRO_DIR}/coder/generator/train/train_generator.cc
${MICRO_DIR}/coder/generator/component/common_component.cc
${MICRO_DIR}/coder/generator/component/weight_component.cc
${MICRO_DIR}/coder/generator/component/cmake_component.cc
@ -31,6 +36,7 @@ set(CODER_GENERATOR_SRC
${MICRO_DIR}/coder/generator/component/const_blocks/load_input.cc
${MICRO_DIR}/coder/generator/component/const_blocks/calib_output.cc
${MICRO_DIR}/coder/generator/component/const_blocks/benchmark.cc
${MICRO_DIR}/coder/generator/component/const_blocks/benchmark_train.cc
${MICRO_DIR}/coder/generator/component/const_blocks/mcontext.cc
)

View File

@ -26,16 +26,18 @@ const std::map<TypeId, size_t> size_map = {{kNumberTypeFloat, sizeof(float)},
{kNumberTypeInt32, sizeof(int32_t)}, {kNumberTypeInt16, sizeof(int16_t)},
{kNumberTypeInt8, sizeof(int8_t)}, {kNumberTypeUInt8, sizeof(uint8_t)}};
}
void *MemoryAllocator::MallocWeightTensor(TypeId type_id, size_t size, MallocType type) {
void *MemoryAllocator::MallocWeightTensor(TypeId type_id, size_t size, MallocType type,
const std::string &tensor_name) {
auto item = size_map.find(type_id);
MS_CHECK_TRUE_RET_NULL(item != size_map.end(), "unsupported type idnex");
MS_CHECK_TRUE_RET_NULL(item != size_map.end(), "unsupported type index");
size_t type_size = item->second;
MS_CHECK_TRUE_RET_NULL(type_size > 0, "type size should");
MS_CHECK_TRUE_RET_NULL(type_size > 0, "type size should be greater than 0");
std::vector<int> shape = {1, static_cast<int>(size / type_size)};
auto cate = type == kOfflinePackWeight ? lite::Category::CONST_TENSOR : lite::Category::VAR;
Tensor *weight = new (std::nothrow) lite::Tensor(type_id, shape, mindspore::NHWC, cate);
MS_CHECK_PTR_RET_NULL(weight);
weight->set_tensor_name(tensor_name);
std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++);
malloc_weights_addr_.insert(std::make_pair(weight, runtime_addr));
if (type == kOfflinePackWeight) {
@ -88,6 +90,9 @@ std::map<Tensor *, std::string> MemoryAllocator::tensors_map() const {
std::map<Tensor *, std::string> res;
res.insert(tensors_addr_.begin(), tensors_addr_.end());
res.insert(malloc_weights_addr_.begin(), malloc_weights_addr_.end());
for (const auto &iter : saved_weights_addr_) {
res.insert({iter.second, iter.first});
}
return res;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -65,9 +65,9 @@ class MemoryAllocator {
* in view of weight, bias and workspace
*/
void *Malloc(TypeId type_id, size_t size, MallocType type) {
void *Malloc(TypeId type_id, size_t size, MallocType type, const std::string &tensor_name = "") {
if (type != kWorkspace) {
return MallocWeightTensor(type_id, size, type);
return MallocWeightTensor(type_id, size, type, tensor_name);
}
if (size == 0 || size >= UINT_MAX) {
return nullptr;
@ -123,7 +123,9 @@ class MemoryAllocator {
it = origin_weights_addr_.find(tensor);
if (it != origin_weights_addr_.end()) {
saved_weights_addr_.insert(std::make_pair(it->second, tensor));
if (immutable) malloc_weights_addr_.insert(std::make_pair(tensor, it->second));
if (immutable) {
malloc_weights_addr_.insert(std::make_pair(tensor, it->second));
}
return it->second;
}
MS_LOG(ERROR) << "uninitialized memory";
@ -138,7 +140,7 @@ class MemoryAllocator {
std::map<std::string, Tensor *> saved_weights() const { return saved_weights_addr_; }
size_t total_buffer_size() const { return tensors_size_ + workspace_size_; }
void enable_is_next() { is_next_ = true; }
void *MallocWeightTensor(TypeId type_id, size_t size, MallocType type);
void *MallocWeightTensor(TypeId type_id, size_t size, MallocType type, const std::string &tensor_name = "");
private:
int AssignTensors(const std::vector<std::unique_ptr<OperatorCoder>> &nodes);

View File

@ -18,17 +18,29 @@
#include <string>
#include <vector>
#include <map>
#include "tools/common/flag_parser.h"
#include "tools/converter/micro/coder/session.h"
#include "tools/converter/micro/coder/context.h"
#include "tools/converter/micro/coder/train/train_session.h"
#include "utils/dir_utils.h"
#include "securec/include/securec.h"
#include "src/common/file_utils.h"
#include "src/common/utils.h"
#include "tools/converter/micro/coder/config.h"
#include "tools/converter/micro/coder/generator/component/component.h"
namespace mindspore::lite::micro {
namespace {
std::shared_ptr<CoderSession> CreateCoderSession() {
std::shared_ptr<CoderSession> session;
auto code_mode = Configurator::GetInstance()->code_mode();
if (code_mode == CodeMode::Inference) {
session = std::make_shared<CoderSession>();
} else if (code_mode == CodeMode::Train) {
session = std::make_shared<CoderTrainSession>();
} else {
MS_LOG(ERROR) << "unsupported code mode. " << code_mode;
session = nullptr;
}
return session;
}
} // namespace
int Coder::Run(const void *model_buff, size_t size) {
session_ = CreateCoderSession();
if (session_ == nullptr) {
@ -133,6 +145,10 @@ int Coder::Init(const std::string &code_mode, const std::string &target, bool su
auto code_item = kCodeModeMap.find(code_mode);
MS_CHECK_TRUE_MSG(code_item != kCodeModeMap.end(), RET_ERROR, "unsupported code mode: " + code_mode);
config->set_code_mode(code_item->second);
if (code_item->second == CodeMode::Train && config->target() == kCortex_M) {
MS_LOG(ERROR) << "Cortex-M cannot support train.";
return RET_ERROR;
}
if (support_parallel && config->target() == kCortex_M) {
MS_LOG(ERROR) << "Cortex-M cannot support parallel.";

View File

@ -74,9 +74,17 @@ class CoderContext {
void set_graph_inputs(const std::vector<Tensor *> &graph_inputs) { graph_inputs_ = graph_inputs; }
void set_graph_outputs(const std::vector<Tensor *> &graph_outputs) { graph_outputs_ = graph_outputs; }
void set_graph_eval_outputs(const std::vector<Tensor *> &graph_eval_outputs) {
graph_eval_outputs_ = graph_eval_outputs;
}
void set_graph_train_outputs(const std::vector<Tensor *> &graph_train_outputs) {
graph_train_outputs_ = graph_train_outputs;
}
std::vector<Tensor *> graph_inputs() const { return graph_inputs_; }
std::vector<Tensor *> graph_outputs() const { return graph_outputs_; }
std::vector<Tensor *> graph_eval_outputs() const { return graph_eval_outputs_; }
std::vector<Tensor *> graph_train_outputs() const { return graph_train_outputs_; }
std::string input_name() { return input_name_; }
std::string output_name() { return output_name_; }
@ -107,6 +115,8 @@ class CoderContext {
private:
std::vector<Tensor *> graph_inputs_;
std::vector<Tensor *> graph_outputs_;
std::vector<Tensor *> graph_eval_outputs_;
std::vector<Tensor *> graph_train_outputs_;
// primitive const tensors, parsed from model, without packed.
std::map<std::string, Tensor *> saved_weights_;
// all tensors, include parsed from model and packed tensors.
@ -134,6 +144,7 @@ class CoderContext {
std::set<std::string> asm_files_;
// operator header files
std::set<std::string> h_files_;
// net.c's content, include the Inference and Training implementation
std::vector<std::string> code_blocks_;
std::vector<std::string> global_code_blocks_;

View File

@ -16,7 +16,6 @@
#include "coder/generator/component/common_component.h"
#include <memory>
#include "coder/generator/component/const_blocks/license.h"
#include "coder/generator/component/component.h"
#include "coder/utils/type_cast.h"
#include "coder/utils/coder_utils.h"
@ -28,6 +27,7 @@ namespace mindspore::lite::micro {
const char model_runtime_init_source[] = R"RAW(
typedef struct {
void *runtime_buffer;
bool train_mode; // true: train mode, false: eval mode
MSTensorHandleArray inputs;
MSTensorHandleArray outputs;
} MicroModel;
@ -58,6 +58,11 @@ void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> &
} else {
ofs << " micro_model->runtime_buffer = " << ctx->buffer_name() << ";\n";
}
if (config.code_mode() == CodeMode::Inference) {
ofs << " micro_model->train_mode = false;\n";
} else if (config.code_mode() == CodeMode::Train) {
ofs << " micro_model->train_mode = true;\n";
}
auto array_tostring = [&ofs](Tensor *tensor, const std::string &prefix, size_t index) {
ofs << kAlignedString << prefix << "_tensors[" << index << "] = malloc(sizeof(MicroTensor));\n";
ofs << kAlignedString << prefix << "_tensors[" << index << "]->type = " << EnumNameMSDataType(tensor->data_type())
@ -76,7 +81,11 @@ void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> &
};
std::vector<Tensor *> inputs = ctx->graph_inputs();
std::vector<Tensor *> outputs = ctx->graph_outputs();
if (config.code_mode() == CodeMode::Inference) {
outputs = ctx->graph_outputs();
} else if (config.code_mode() == CodeMode::Train) {
outputs = ctx->graph_train_outputs();
}
size_t inputs_size = inputs.size();
ofs << " MSTensorHandleArray model_inputs;\n";
ofs << " model_inputs.handle_num = " << inputs_size << ";\n";
@ -174,7 +183,7 @@ MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, M
ofs << " }\n";
ofs << " SetInputs(inputs_data_array, " << inputs_num << ");\n";
ofs << "\n";
ofs << " Inference();\n";
ofs << " Execute(micro_model->train_mode);\n";
ofs << "\n";
ofs << " void *outputs_data_array[" << outputs_num << "];\n";
ofs << " for (int i = 0; i < " << outputs_num << "; i++) {\n";
@ -191,7 +200,6 @@ void CodeCopyOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderCon
auto tensor_map = ctx->tensors_map();
std::vector<Tensor *> outputs = ctx->graph_outputs();
size_t outputs_size = outputs.size();
ofs << "int CopyOutputsData(void **outputs, int num) {\n"
" if (outputs == NULL) {\n"
" return RET_ERROR;\n"
@ -350,11 +358,11 @@ void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderCo
ofs << "}\n";
}
void CodeInferenceState(std::ofstream &ofs) {
void CodeExecuteState(std::ofstream &ofs) {
ofs << "/**\n"
<< " * net inference function\n"
<< " * net execute function\n"
<< " **/\n"
<< "void "
<< "Inference();\n\n";
<< "Execute(bool train_mode);\n\n";
}
} // namespace mindspore::lite::micro

View File

@ -49,6 +49,6 @@ void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderCo
void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx,
const Configurator &config);
void CodeInferenceState(std::ofstream &ofs);
void CodeExecuteState(std::ofstream &ofs);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_COMMON_COMPONENT_H_

View File

@ -56,7 +56,7 @@ void usage() {
"args[3]: loop count for performance test\n"
"args[4]: calibration file\n"
"args[5]: runtime thread num, default is 1\n"
"args[6]: runtime thread bind mode, 0: No bind, 1: Bind hign cpu, 2: Bind mid cpu, default is 1\n"
"args[6]: runtime thread bind mode, 0: No bind, 1: Bind high cpu, 2: Bind mid cpu, default is 1\n"
"args[7]: warm up loop count, default is 3\n\n");
}
@ -218,7 +218,7 @@ int main(int argc, const char **argv) {
ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelPredict failed, ret: %d", kMSStatusSuccess);
printf("MSModelPredict failed, ret: %d", ret);
return ret;
}
}
@ -231,7 +231,7 @@ int main(int argc, const char **argv) {
ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelPredict failed, ret: %d", kMSStatusSuccess);
printf("MSModelPredict failed, ret: %d", ret);
return ret;
}
}

View File

@ -0,0 +1,308 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/generator/component/const_blocks/benchmark_train.h"
namespace mindspore::lite::micro {
const char benchmark_train_source[] = R"RAW(/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "load_input.h"
#include "calib_output.h"
#include "c_api/types_c.h"
#include "c_api/model_c.h"
#include "c_api/context_c.h"
#include "src/tensor.h"
#include <time.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define kMaxThreadNum 4
void usage() {
printf(
"-- mindspore benchmark_train params usage:\n"
"args[0]: executable file\n"
"args[1]: inputs binary file\n"
"args[2]: model weight binary file\n"
"args[3]: loop count for performance test\n"
"args[4]: calibration file\n"
"args[5]: runtime thread num, default is 1\n"
"args[6]: runtime thread bind mode, 0: No bind, 1: Bind high cpu, 2: Bind mid cpu, default is 1\n"
"args[7]: warm up loop count, default is 3\n\n");
}
uint64_t GetTimeUs() {
const int USEC = 1000000;
const int MSEC = 1000;
struct timespec ts = {0, 0};
if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) {
return 0;
}
uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC));
return retval;
}
void PrintTensorHandle(MSTensorHandle tensor) {
printf("name: %s, ", MSTensorGetName(tensor));
MSDataType data_type = MSTensorGetDataType(tensor);
printf("DataType: %d, ", data_type);
size_t element_num = (size_t)(MSTensorGetElementNum(tensor));
printf("Elements: %zu, ", element_num);
printf("Shape: [");
size_t shape_num = 0;
const int64_t *dims = MSTensorGetShape(tensor, &shape_num);
for (size_t i = 0; i < shape_num; i++) {
printf("%d ", (int)dims[i]);
}
printf("], Data: \n");
void *data = MSTensorGetMutableData(tensor);
element_num = element_num > 10 ? 10 : element_num;
switch (data_type) {
case kMSDataTypeNumberTypeFloat32: {
for (size_t i = 0; i < element_num; i++) {
printf("%.6f, ", ((float *)data)[i]);
}
printf("\n");
} break;
case kMSDataTypeNumberTypeFloat16:
case kMSDataTypeNumberTypeInt16: {
for (size_t i = 0; i < element_num; i++) {
printf("%" PRId16, ((int16_t *)data)[i]);
}
printf("\n");
} break;
case kMSDataTypeNumberTypeInt32: {
for (size_t i = 0; i < element_num; i++) {
printf("%" PRId32, ((int32_t *)data)[i]);
}
printf("\n");
} break;
case kMSDataTypeNumberTypeInt8: {
for (size_t i = 0; i < element_num; i++) {
printf("%" PRIi8, ((int8_t *)data)[i]);
}
printf("\n");
} break;
case kMSDataTypeNumberTypeUInt8: {
for (size_t i = 0; i < element_num; i++) {
printf("%u", ((uint8_t *)data)[i]);
}
printf("\n");
} break;
default:
printf("Unsupported data type to print");
break;
}
}
int main(int argc, const char **argv) {
if (argc < 2) {
printf("input command is invalid\n");
usage();
return kMSStatusLiteError;
}
printf("=======run benchmark_train======\n");
MSContextHandle ms_context_handle = NULL;
if (argc >= 6) {
int thread_num = atoi(argv[5]);
if (thread_num < 1 || thread_num > kMaxThreadNum) {
printf("Thread number error! It should be greater than 0 and less than 5\n");
return kMSStatusLiteParamInvalid;
}
int bind_mode = 1;
if (argc >= 7) {
bind_mode = atoi(argv[6]);
if (bind_mode < 0 || bind_mode > 2) {
printf("Thread bind mode error! 0: No bind, 1: Bind hign cpu, 2: Bind mid cpu.\n");
return kMSStatusLiteParamInvalid;
}
}
ms_context_handle = MSContextCreate();
if (ms_context_handle) {
MSContextSetThreadNum(ms_context_handle, thread_num);
MSContextSetThreadAffinityMode(ms_context_handle, bind_mode);
}
printf("context: ThreadNum: %d, BindMode: %d\n", thread_num, bind_mode);
}
void *model_buffer = NULL;
int model_size = 0;
// read .bin file by ReadBinaryFile;
if (argc >= 3) {
model_buffer = ReadInputData(argv[2], &model_size);
}
MSModelHandle model_handle = MSModelCreate();
int ret = MSModelBuild(model_handle, model_buffer, model_size, kMSModelTypeMindIR, ms_context_handle);
MSContextDestroy(&ms_context_handle);
if (ret != kMSStatusSuccess) {
printf("MSModelBuildFromFile failed, ret: %d\n", ret);
free(model_buffer);
model_buffer = NULL;
return ret;
}
if (model_buffer) {
free(model_buffer);
model_buffer = NULL;
}
// set model inputs tensor data
MSTensorHandleArray inputs_handle = MSModelGetInputs(model_handle);
if (inputs_handle.handle_list == NULL) {
printf("MSModelGetInputs failed, ret: %d", ret);
return ret;
}
size_t inputs_num = inputs_handle.handle_num;
void *inputs_binbuf[inputs_num];
int inputs_size[inputs_num];
for (size_t i = 0; i < inputs_num; ++i) {
MSTensorHandle tensor = inputs_handle.handle_list[i];
inputs_size[i] = (int)MSTensorGetDataSize(tensor);
}
ret = ReadInputsFile((char *)(argv[1]), inputs_binbuf, inputs_size, (int)inputs_num);
if (ret != 0) {
MSModelDestroy(&model_handle);
return ret;
}
for (size_t i = 0; i < inputs_num; ++i) {
void *input_data = MSTensorGetMutableData(inputs_handle.handle_list[i]);
memcpy(input_data, inputs_binbuf[i], inputs_size[i]);
free(inputs_binbuf[i]);
inputs_binbuf[i] = NULL;
}
MSTensorHandleArray outputs_handle = MSModelGetOutputs(model_handle);
if (!outputs_handle.handle_list) {
printf("MSModelGetOutputs failed, ret: %d", ret);
return ret;
}
int warm_up_loop_count = 3;
if (argc >= 8) {
warm_up_loop_count = atoi(argv[7]);
if (warm_up_loop_count < 0) {
printf("The warm up loop count error! Cannot be less than 0.\n");
return kMSStatusLiteParamInvalid;
}
}
printf("Running warm up loops...\n");
for (int i = 0; i < warm_up_loop_count; ++i) {
ret = MSModelRunStep(model_handle, NULL, NULL);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelRunStep failed, ret: %d", ret);
return ret;
}
}
if (argc >= 4) {
int loop_count = atoi(argv[3]);
printf("\nloop count: %d\n", loop_count);
uint64_t start_time = GetTimeUs();
for (int i = 0; i < loop_count; ++i) {
ret = MSModelRunStep(model_handle, NULL, NULL);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelRunStep failed, ret: %d", ret);
return ret;
}
}
uint64_t end_time = GetTimeUs();
float total_time = (float)(end_time - start_time) / 1000.0f;
printf("total time: %.5fms, per time: %.5fms\n", total_time, total_time / loop_count);
}
ret = MSModelRunStep(model_handle, NULL, NULL);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelRunStep failed, ret: %d", ret);
return ret;
}
printf("========run train mode success=======\n");
printf("outputs: \n");
for (size_t i = 0; i < outputs_handle.handle_num; i++) {
MSTensorHandle output = outputs_handle.handle_list[i];
PrintTensorHandle(output);
}
ret = MSModelSetTrainMode(model_handle, false); // when change train mode, outputs handle needs to be refreshed
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelSetTrainMode failed, ret: %d", ret);
return ret;
}
outputs_handle = MSModelGetOutputs(model_handle);
if (!outputs_handle.handle_list) {
printf("MSModelGetOutputs failed, ret: %d", ret);
return ret;
}
ret = MSModelRunStep(model_handle, NULL, NULL);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelRunStep failed, ret: %d", ret);
return ret;
}
printf("\n========run eval mode success=======\n");
printf("outputs: \n");
for (size_t i = 0; i < outputs_handle.handle_num; i++) {
MSTensorHandle output = outputs_handle.handle_list[i];
PrintTensorHandle(output);
}
if (argc >= 5) {
CalibTensor *calib_tensors;
int calib_num = 0;
ret = ReadCalibData(argv[4], &calib_tensors, &calib_num);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
return ret;
}
ret = CompareOutputs(outputs_handle, &calib_tensors, calib_num);
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
return ret;
}
FreeCalibTensors(&calib_tensors, calib_num);
}
ret = MSModelExportWeight(model_handle, "./export.bin");
if (ret != kMSStatusSuccess) {
MSModelDestroy(&model_handle);
printf("MSModelExportWeight failed, ret: %d", ret);
return ret;
}
printf("========export weight success=======\n");
printf("========run success=======\n");
MSModelDestroy(&model_handle);
return kMSStatusSuccess;
}
)RAW";
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,23 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_CONST_BLOCKS_BENCHMARK_TRAIN_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_CONST_BLOCKS_BENCHMARK_TRAIN_H_
namespace mindspore::lite::micro {
extern const char benchmark_train_source[];
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_CONST_BLOCKS_BENCHMARK_TRAIN_H_

View File

@ -143,6 +143,7 @@ int ReadInputsFile(char *path, void **buffers, const int *inputs_size, int input
buffers[i] = ReadInputData(inputs_path[i], &size);
if (size != inputs_size[i] || buffers[i] == NULL) {
printf("size mismatch, %s, input: %d, needed: %d\n", inputs_path[i], size, inputs_size[i]);
free(buffers[i]);
return kMSStatusLiteError;
}
}

View File

@ -22,7 +22,7 @@ void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs) {
if (inputs.handle_list == NULL) {
return;
}
for (int i = 0; i < inputs.handle_num; i++) {
for (size_t i = 0; i < inputs.handle_num; i++) {
MicroTensor *micro_tensor = inputs.handle_list[i];
if (!micro_tensor) {
continue;

View File

@ -16,166 +16,161 @@
#include "coder/generator/component/train_component.h"
#include <string>
#include "nnacl/op_base.h"
#include "coder/utils/type_cast.h"
namespace mindspore::lite::micro {
void CodeTrainParams(std::ofstream &ofs) {
ofs << "struct TrainParameter {\n"
" float beta1_;\n"
" float beta2_;\n"
" float epsilon_;\n"
"};\n"
"\n"
"enum EarlyStopType {\n"
" Diff = 0,\n"
" WeigthDiff = 1,\n"
" Abs = 2,\n"
"};\n"
"\n"
"struct EarlyStop {\n"
" enum EarlyStopType type;\n"
" float tolerate;\n"
"};\n\n";
}
void CodeMSModelSetTrainMode(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
std::vector<Tensor *> train_outputs = ctx->graph_train_outputs();
std::vector<Tensor *> eval_outputs = ctx->graph_eval_outputs();
auto train_outputs_size = train_outputs.size();
auto eval_outputs_size = eval_outputs.size();
void CodeFeaturesState(std::ofstream &ofs) {
ofs << "/**\n"
" *\n"
" * @param size, return the number of features\n"
" * @return, the address of features\n"
" */\n"
<< "FeatureParam *GetFeatures(int *size);\n\n";
ofs << "/**\n"
" *\n"
" * @param features, the address of features\n"
" * @param size, the number of features\n"
" * @return, status\n"
" */\n"
<< "int UpdateFeatures(FeatureParam *features, int size);\n\n";
}
void CodeFeaturesImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
size_t features_num = 0;
ofs << "static FeatureParam feature_params[] = {\n";
for (const auto &item : ctx->saved_weights()) {
std::string addr = item.first;
Tensor *tensor = item.second;
if (tensor->tensor_name().empty()) {
MS_LOG(ERROR) << "exist empty feature";
continue;
auto array_tostring = [&ofs](Tensor *tensor, size_t index) {
ofs << " output_tensors[" << index << "] = malloc(sizeof(MicroTensor));\n";
ofs << " output_tensors[" << index << "]->type = " << EnumNameMSDataType(tensor->data_type()) << ";\n";
ofs << " output_tensors[" << index << "]->format = kMSFormatNHWC;\n";
ofs << " output_tensors[" << index << "]->ndim = " << tensor->shape().size() << ";\n";
size_t shape_size = tensor->shape().size();
ofs << " output_tensors[" << index << "]->shape = "
<< "malloc(" << shape_size << " * sizeof(int64_t));\n";
for (size_t i = 0; i < shape_size; i++) {
ofs << " output_tensors[" << index << "]->shape[" << i << "]= " << tensor->shape()[i] << ";\n";
}
ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", "
<< EnumMicroTensorDataType(tensor->data_type()) << "}, \n";
features_num++;
ofs << " output_tensors[" << index << "]->name = \"" << tensor->tensor_name() << "\";\n";
ofs << " output_tensors[" << index << "]->data = NULL;\n";
};
ofs << R"RAW(
MSStatus MSModelSetTrainMode(MSModelHandle model, bool train) {
MicroModel *micro_model = (MicroModel *)model;
if (micro_model == NULL) {
return kMSStatusLiteNullptr;
}
ofs << "};\n";
ofs << "FeatureParam *GetFeatures(int *size) {\n"
<< " *size = " << features_num << ";\n"
<< " return feature_params;\n"
"}\n\n";
ofs << "int "
<< "UpdateFeatures(FeatureParam *features, int size) {\n"
<< " for (int i = 0; i < size; ++i) {\n"
" FeatureParam *src = features + i;\n"
" FeatureParam dst;\n"
" // find the dst feature\n"
" bool is_find = false;\n"
<< " for (int j = 0; j < " << features_num << "; ++j) {\n"
<< " if (strcmp(src->name, feature_params[j].name) == 0) {\n"
" dst = feature_params[j];\n"
" is_find = true;\n"
" break;\n"
" }\n"
" }\n"
" if (!is_find) {\n"
" MICRO_ERROR(\"invalid feature param: %s\", src->name);\n"
" return RET_ERROR;\n"
" }\n"
" if (src->elenums != dst.elenums) {\n"
" MICRO_ERROR(\"feature %s elenums is mismatch, src: %lu, dst: %lu\", src->name, src->elenums, "
"dst.elenums);\n"
" return RET_ERROR;\n"
" }\n"
" memcpy(dst.data, src->data, src->elenums * sizeof(float));\n"
" }\n"
" MICRO_INFO(\"update features map success\");\n"
" return RET_OK;\n"
micro_model->train_mode = train;
MSTensorHandleArrayDestroy(micro_model->outputs);
)RAW";
ofs << " if (train) {\n"
<< " MSTensorHandleArray model_outputs;\n"
<< " model_outputs.handle_num = " << train_outputs_size << ";\n"
<< " MicroTensor **output_tensors = malloc(" << train_outputs_size << " * sizeof(MicroTensor *));\n"
<< " model_outputs.handle_list = (MSTensorHandle *)(output_tensors);\n"
<< " micro_model->outputs = model_outputs;\n";
for (size_t i = 0; i < train_outputs_size; ++i) {
Tensor *output = train_outputs[i];
array_tostring(output, i);
}
ofs << " } else {\n"
<< " MSTensorHandleArray model_outputs;\n"
<< " model_outputs.handle_num = " << eval_outputs_size << ";\n"
<< " MicroTensor **output_tensors = malloc(" << eval_outputs_size << " * sizeof(MicroTensor *));\n"
<< " model_outputs.handle_list = (MSTensorHandle *)(output_tensors);\n"
<< " micro_model->outputs = model_outputs;\n";
for (size_t i = 0; i < eval_outputs_size; ++i) {
Tensor *output = eval_outputs[i];
array_tostring(output, i);
}
ofs << " }\n"
<< " return kMSStatusSuccess;\n"
"}\n\n";
}
void CodeTrainState(std::ofstream &ofs) {
ofs
<< "/**\n"
" * Train Function\n"
" * @param epoch, the train epoch\n"
" * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n"
" * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n"
" * parameters to improve the accuracy\n"
" * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n"
" * @return status\n"
" */\n"
<< "int Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, "
"const struct EarlyStop *early_stop);\n\n";
}
void CodeTrainImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
std::vector<Tensor *> inputs = ctx->graph_inputs();
size_t inputs_num = inputs.size();
auto inputs_tostring = [&inputs, &ctx]() {
std::string result;
result += "{";
for (size_t i = 0; i < inputs.size(); ++i) {
result += ctx->input_name() + std::to_string(i) + ", ";
void CodeMSModelRunStep(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
auto inputs_size = ctx->graph_inputs().size();
size_t train_outputs_size = ctx->graph_train_outputs().size();
size_t eval_outputs_size = ctx->graph_eval_outputs().size();
ofs << R"RAW(
MSStatus MSModelRunStep(MSModelHandle model, const MSKernelCallBackC before, const MSKernelCallBackC after) {
MicroModel *micro_model = (MicroModel *)model;
if (micro_model == NULL) {
return kMSStatusLiteNullptr;
}
result += "}";
return result;
};
auto wrap = [](size_t i) { return "[" + std::to_string(i) + "]"; };
auto offset_inputs = [&inputs, &wrap]() {
std::string src = "origin_inputs";
std::string dst = "input_ptr";
std::string result;
for (size_t i = 0; i < inputs.size(); ++i) {
result += dst + wrap(i) += " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n";
}
return result;
};
auto varify_inputs = [&inputs, &wrap]() {
std::string result;
for (size_t i = 0; i < inputs.size(); ++i) {
result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL";
i < inputs.size() - 1 ? result += " || " : result += "";
}
return result;
};
ofs << "int Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter "
"*parameter, const struct EarlyStop *early_stop) {\n"
" if (iterations <= 0 || epoch <= 0) {\n"
" MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n"
" return RET_ERROR;\n"
" }\n"
" MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n"
<< " const void *origin_input[] = " << inputs_tostring() << ";\n";
ofs << " if (" << varify_inputs() << ") {\n"
<< " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n"
" return RET_ERROR;\n"
" }\n";
ofs << " for (int i = 0; i < epoch; ++i) {\n"
<< " const void *input_ptr[" << inputs_num << "];\n"
<< " float loss = 0;\n"
<< " for (int j = 0; j < iterations; ++j) {\n"
<< " " << offset_inputs() << "\n"
<< " "
<< "_SetInputs(input_ptr, " << inputs_num << ");\n"
<< " "
<< "_Inference();\n"
<< " loss = "
<< "ComputeLossAndGradient();\n"
bool train_mode = micro_model->train_mode;
)RAW";
ofs << " if (micro_model->inputs.handle_num != " << inputs_size << ") {\n"
<< " return kMSStatusLiteParamInvalid;\n"
<< " }\n"
" }\n"
" return RET_OK;\n"
"};\n\n";
<< " const void *inputs_data_array[" << inputs_size << "];\n"
<< " for (int i = 0; i < " << inputs_size << "; i++) {\n"
<< " inputs_data_array[i] = ((MicroTensor *)(micro_model->inputs.handle_list[i]))->data;\n"
<< " }\n"
<< " SetInputs(inputs_data_array, " << inputs_size << ");\n"
<< "\n"
<< " Execute(train_mode);\n\n"
<< " // copy data to outputs handle\n"
<< " if (train_mode) {\n"
<< " if (micro_model->outputs.handle_num != " << train_outputs_size << ") {\n"
<< " return kMSStatusLiteParamInvalid;\n"
<< " }\n"
<< " void *outputs_data_array[" << train_outputs_size << "];\n"
<< " for (int i = 0; i < " << train_outputs_size << "; i++) {\n"
<< " outputs_data_array[i] = MSTensorGetMutableData(micro_model->outputs.handle_list[i]);\n"
<< " }\n"
<< " if (CopyOutputsDataWithFlag(outputs_data_array, " << train_outputs_size << ", true) != RET_OK) {\n"
<< " return kMSStatusLiteError;\n"
<< " }\n"
<< " } else {\n"
<< " if (micro_model->outputs.handle_num != " << eval_outputs_size << ") {\n"
<< " return kMSStatusLiteParamInvalid;\n"
<< " }\n"
<< " void *outputs_data_array[" << eval_outputs_size << "];\n"
<< " for (int i = 0; i < " << eval_outputs_size << "; i++) {\n"
<< " outputs_data_array[i] = MSTensorGetMutableData(micro_model->outputs.handle_list[i]);\n"
<< " }\n"
<< " if (CopyOutputsDataWithFlag(outputs_data_array, " << eval_outputs_size << ", false) != RET_OK) {\n"
<< " return kMSStatusLiteError;\n"
<< " }\n"
<< " }\n"
<< " return kMSStatusSuccess;\n"
<< "}\n\n";
}
void CodeMSModelExportWeight(std::ofstream &ofs) {
ofs << R"RAW(
MSStatus MSModelExportWeight(MSModelHandle model, const char *export_path) {
int ret = Export(export_path);
return ret == RET_OK ? kMSStatusSuccess : kMSStatusLiteError;
})RAW";
ofs << "\n\n";
}
void CodeCopyTrainOutputsState(std::ofstream &ofs) {
ofs << "int CopyOutputsDataWithFlag(void **outputs, int num, bool train_mode);\n\n";
}
void CodeCopyTrainOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
auto tensor_map = ctx->tensors_map();
std::vector<Tensor *> train_outputs = ctx->graph_train_outputs();
std::vector<Tensor *> eval_outputs = ctx->graph_eval_outputs();
size_t train_outputs_size = train_outputs.size();
size_t eval_outputs_size = eval_outputs.size();
ofs << "int CopyOutputsDataWithFlag(void **outputs, int num, bool train_mode) {\n"
<< " if (outputs == NULL) {\n"
<< " return RET_ERROR;\n"
<< " }\n"
<< " if (train_mode) {\n"
<< " if (num != " << train_outputs_size << ") {\n"
<< " return RET_ERROR;\n"
<< " }\n";
for (size_t i = 0; i < train_outputs_size; ++i) {
Tensor *output = train_outputs[i];
MS_CHECK_PTR_IF_NULL(output);
MS_CHECK_TRUE_RET_VOID(tensor_map.find(output) != tensor_map.end());
ofs << " memcpy(outputs[" << i << "], " << tensor_map[output] << ", " << output->Size() << ");\n";
}
ofs << " } else {\n"
<< " if (num != " << eval_outputs_size << ") {\n"
<< " return RET_ERROR;\n"
<< " }\n";
for (size_t i = 0; i < eval_outputs_size; ++i) {
Tensor *output = eval_outputs[i];
MS_CHECK_PTR_IF_NULL(output);
MS_CHECK_TRUE_RET_VOID(tensor_map.find(output) != tensor_map.end());
ofs << " memcpy(outputs[" << i << "], " << tensor_map[output] << ", " << output->Size() << ");\n";
}
ofs << " }\n"
<< " return RET_OK;\n"
"}\n\n";
}
} // namespace mindspore::lite::micro

View File

@ -26,12 +26,10 @@
#include "tools/converter/micro/coder/context.h"
namespace mindspore::lite::micro {
void CodeTrainParams(std::ofstream &ofs);
void CodeFeaturesState(std::ofstream &ofs);
void CodeFeaturesImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx);
void CodeTrainState(std::ofstream &ofs);
void CodeTrainImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx);
void CodeMSModelSetTrainMode(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx);
void CodeMSModelRunStep(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx);
void CodeMSModelExportWeight(std::ofstream &ofs);
void CodeCopyTrainOutputsState(std::ofstream &ofs);
void CodeCopyTrainOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_TRAIN_COMPONENT_H_

View File

@ -83,27 +83,36 @@ void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::
}
if (CheckConstantTensor(tensor)) {
if (config.target() != kCortex_M) {
hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[];\n";
hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[]; // " << tensor->tensor_name()
<< std::endl;
cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n";
} else {
hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[];\n";
hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[]; // " << tensor->tensor_name()
<< std::endl;
}
} else if (tensor->category() == lite::Category::VAR) {
hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";\n";
hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << "; // " << tensor->tensor_name()
<< std::endl;
cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n";
}
}
cofs << "\n";
}
void CodeInitWeightState(std::ofstream &ofs) {
ofs << "/**\n"
<< " * @param weight_buffer, the address of the weight binary file\n"
<< " * @param weight_size, the size of the model file in bytes\n"
<< " **/\n"
void CodeInitWeightState(std::ofstream &ofs, const Configurator &config) {
ofs << "/// \\brief Init model weight from buffer.\n\n"
<< "/// \\param[in] weight_buffer The address of the weight binary file.\n"
<< "/// \\param[in] weight_size The size of the weight file in bytes.\n"
<< "int Init(void *weight_buffer, int weight_size);\n\n";
}
void CodeExportWeightState(std::ofstream &ofs, const Configurator &config) {
ofs << "/// \\brief Export model weight to the specified path.\n\n"
<< "/// \\param[in] output_weight_file The path of the export weight file.\n\n"
<< "/// \\return 0 on success or -1 in case of error.\n"
<< "int Export(const char* output_weight_file);\n\n";
}
void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) {
if (config.target() != kCortex_M) {
ofs << "static size_t PackWeightSize() {\n";
@ -114,20 +123,15 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
ofs << " return w_size;\n";
ofs << "}\n\n";
ofs << "int Init(void *weight_buffer, int weight_size) {\n"
<< " if (weight_buffer == NULL) {\n"
<< " return RET_ERROR;\n"
<< " }\n";
ofs << " struct ModelParameter {\n"
ofs << "struct ModelParameter {\n"
<< " void *addr;\n"
<< " size_t size;\n"
<< " size_t offset;\n"
<< " };\n";
<< "};\n\n";
ofs << " size_t " << ctx->weight_size_name() << " = PackWeightSize();\n";
// generate weight struct array
size_t params_num = 0;
size_t offset = 0;
std::string params;
std::string origins;
for (const auto &item : ctx->saved_weights()) {
std::string name = item.first;
@ -135,21 +139,21 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
if (!CheckConstantTensor(tensor)) {
continue;
}
std::map<Tensor *, std::string> ctx_tensor_map = ctx->tensors_map();
auto iter = ctx_tensor_map.find(tensor);
if (iter != ctx_tensor_map.end()) {
origins += " {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n";
params_num++;
} else {
TypeId data_type = tensor->data_type();
params +=
" " + GetTensorDataType(data_type) + "*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n";
}
offset += tensor->Size();
}
ofs << params << "\n";
ofs << " struct ModelParameter model_params[] = {\n" << origins << " };\n";
ofs << "struct ModelParameter model_params[] = {\n" << origins << " };\n";
ofs << "\n";
// generate weight init function
ofs << "int Init(void *weight_buffer, int weight_size) {\n"
<< " if (weight_buffer == NULL) {\n"
<< " return RET_ERROR;\n"
<< " }\n";
ofs << " size_t " << ctx->weight_size_name() << " = PackWeightSize();\n";
ofs << " for(int i = 0; i < " << params_num << "; ++i) {\n"
<< " if (model_params[i].offset + model_params[i].size > weight_size) {\n"
" return RET_ERROR;\n"
@ -165,6 +169,8 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
ofs << "int Init(void *weight_buffer, int weight_size) {\n";
ofs << " const size_t w_size = " << ctx->weight_buffer_size() << ";\n";
}
// generate matrix weight init func
ofs << " size_t " << ctx->weight_offset_name() << " = 0;\n";
for (const auto &block : ctx->init_contents()) {
ofs << "{\n" << block << "}\n";
@ -175,6 +181,29 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
ofs << "}\n\n";
}
void CodeWeightExportFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) {
if (config.target() == kCortex_M) {
MS_LOG(DEBUG) << "weight file is unsupported to export when in Cortex M mode.";
return;
}
ofs << "int Export(const char* output_weight_file) {\n"
<< " if (output_weight_file == NULL) {\n"
<< " return RET_ERROR;\n"
<< " }\n\n"
<< " FILE *fp;\n"
<< " if((fp = fopen(output_weight_file, \"wb\")) == NULL) {\n"
<< " printf(\"open file failed.\");\n"
<< " return RET_ERROR;\n"
<< " }\n"
<< " int params_len = sizeof(model_params) / sizeof(model_params[0]);\n"
<< " for (int i = 0; i < params_len; ++i) {\n"
<< " fwrite(model_params[i].addr, sizeof(char), model_params[i].size, fp);\n"
<< " }\n"
<< " fclose(fp);\n"
<< " return RET_OK;\n"
<< "}\n";
}
void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const std::string &net_file) {
std::ofstream net(net_file, std::ios::out | std::ios::trunc | std::ios::binary);
MS_CHECK_TRUE_WITHOUT_RET(net.is_open(), "net file open failed!");

View File

@ -35,7 +35,9 @@ void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const s
void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx,
const Configurator &config);
void CodeInitWeightState(std::ofstream &ofs);
void CodeInitWeightState(std::ofstream &ofs, const Configurator &config);
void CodeExportWeightState(std::ofstream &ofs, const Configurator &config);
void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config);
void CodeWeightExportFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_WEIGHT_COMPONENT_H_

View File

@ -29,7 +29,9 @@
#include "coder/generator/component/const_blocks/mtensor.h"
#include "coder/generator/component/const_blocks/mcontext.h"
#include "coder/generator/component/const_blocks/benchmark.h"
#include "coder/generator/component/const_blocks/benchmark_train.h"
#include "coder/generator/component/const_blocks/license.h"
#include "coder/generator/component/train_component.h"
#include "coder/log.h"
#include "coder/opcoders/parallel.h"
#include "coder/opcoders/kernel_registry.h"
@ -61,26 +63,6 @@ Generator::Generator(std::unique_ptr<CoderContext> ctx) {
Generator::~Generator() { (void)umask(origin_umask_); }
void Generator::CodeNetRunFunc(std::ofstream &ofs) {
// generate net inference code
ofs << "void Inference() {\n";
if (config_->support_parallel()) {
ofs << " " << gThreadNum << " = GetCurrentThreadNum();\n";
ofs << " SetSpinCountMaxValue();\n";
}
for (const auto &block : ctx_->code_blocks()) {
ofs << " {\n" << block << " }\n";
}
for (const auto &block : ctx_->after_inference_code_blocks()) {
ofs << block << "\n";
}
if (config_->support_parallel()) {
ofs << " SetSpinCountMinValue();\n";
}
ofs << "}\n";
}
int Generator::CodeSourceCMakeFile() {
std::string src_cmake_file = net_src_file_path_ + cmake_file_name_;
std::ofstream ofs(src_cmake_file);
@ -100,7 +82,7 @@ int Generator::CodeDataCFile() {
cofs << "#include \"data.h\"\n";
auto inputs_num = ctx_->graph_inputs().size();
auto outputs_num = ctx_->graph_outputs().size();
auto outputs_num = ctx_->graph_eval_outputs().size();
cofs << "#define NET_INPUTS_NUM " << inputs_num << "\n";
cofs << "#define NET_OUTPUTS_NUM " << outputs_num << "\n";
@ -129,7 +111,7 @@ int Generator::CodeDataCFile() {
<< " },\n";
}
for (size_t i = 0; i < outputs_num; i++) {
Tensor *tensor = ctx_->graph_outputs()[i];
Tensor *tensor = ctx_->graph_eval_outputs()[i];
cofs << "#define NET_OUTPUT" << i << "_SIZE " << tensor->ElementsNum() << "\n";
data_def << "float output" << i << "_data[NET_OUTPUT" << 0 << "_SIZE];\n";
calib_data_def << "float calib_output" << i << "_data[NET_OUTPUT" << 0 << "_SIZE] = {};\n";
@ -174,6 +156,9 @@ int Generator::CodeStaticContent() {
std::string context_source_txt = context_source;
std::string tensor_header_txt = tensor_header;
std::string tensor_source_txt = tensor_source;
if (config_->code_mode() == CodeMode::Train) {
benchmark_source_txt = benchmark_train_source;
}
if (config_->target() == kCortex_M) {
bench_cmake_lists_txt = bench_cmake_lists_cortex;
calib_header_txt = calib_header_cortex;
@ -233,7 +218,13 @@ int Generator::CodeMSModelImplement() {
CodeMSModelCreate(ofs, ctx_, *config_);
CodeMSModelBuild(ofs, config_);
ofs << model_runtime_other_source;
if (config_->code_mode() == CodeMode::Train) {
CodeMSModelRunStep(ofs, ctx_);
CodeMSModelSetTrainMode(ofs, ctx_);
CodeMSModelExportWeight(ofs);
} else {
CodeMSModelPredict(ofs, ctx_);
}
CodeMSModelDestory(ofs, config_);
return RET_OK;
}
@ -253,6 +244,7 @@ int Generator::CodeWeightFile() {
MS_LOG(INFO) << "write " << cfile;
cofs << g_hwLicense;
cofs << "#include \"" << net_weight_hfile_ << "\"\n\n";
cofs << "#include <stdio.h>\n\n";
cofs << "int " << gThreadNum << " = 1; \n";
std::vector<Tensor *> inputs = ctx_->graph_inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
@ -275,14 +267,47 @@ int Generator::CodeWeightFile() {
CodeModelParamsData(cofs, ctx_->saved_weights());
}
CodeModelParamsForNet(hofs, cofs, ctx_, *config_);
CodeInitWeightState(hofs);
CodeInitWeightState(hofs, *config_);
CodeWeightInitFunc(cofs, ctx_, *config_);
if (config_->code_mode() == CodeMode::Train) {
CodeExportWeightState(hofs, *config_);
CodeWeightExportFunc(cofs, ctx_, *config_);
}
hofs.close();
cofs.close();
return RET_OK;
}
void Generator::CodeCommonNetH(std::ofstream &ofs) {
ofs << g_hwLicense;
ofs << kExternCpp;
CodeInputState(ofs);
if (is_get_quant_args_) {
CodeGraphQuantArgsState(ofs);
}
CodeManageResourceState(ofs);
CodeExecuteState(ofs);
}
void Generator::CodeCommonNetC(std::ofstream &ofs) {
ofs << g_hwLicense << "\n"
<< "#include \"" << net_weight_hfile_ << "\"\n"
<< "#include \"" << net_inc_hfile_ << "\"\n\n";
if (config_->support_parallel()) {
ofs << "#include \"" << kThreadWrapper << "\"\n\n";
}
if (config_->debug_mode()) {
ofs << "#include \"" << kDebugUtils << "\"\n";
}
CodeGlobalCodeBlocks(ofs, ctx_);
CodeInputImplement(ofs, ctx_);
CodeInitResourceImplement(ofs, ctx_);
CodeFreeResourceImplement(ofs, ctx_, *config_);
if (is_get_quant_args_) {
CodeGraphQuantArgsImplement(ofs, ctx_);
}
}
int Generator::CodeRegKernelHFile() {
if (!KernelRegistry::GetInstance()->HasKernelRegistered()) return RET_OK;
if (!KernelRegistry::GetInstance()->CheckRegistered(schema::PrimitiveType_Custom)) {

View File

@ -44,10 +44,12 @@ class Generator {
protected:
virtual int CodeNetHFile() = 0;
virtual int CodeNetCFile() = 0;
virtual void CodeNetExecuteFunc(std::ofstream &ofs) = 0;
virtual int CodeWeightFile();
virtual int CodeRegKernelHFile();
void CodeNetRunFunc(std::ofstream &ofs);
void CodeCommonNetH(std::ofstream &ofs);
void CodeCommonNetC(std::ofstream &ofs);
Configurator *config_{nullptr};
std::unique_ptr<CoderContext> ctx_{nullptr};

View File

@ -15,29 +15,40 @@
*/
#include "coder/generator/inference/inference_generator.h"
#include <vector>
#include <string>
#include "coder/generator/component/common_component.h"
#include "coder/generator/component/weight_component.h"
#include "coder/generator/component/const_blocks/license.h"
#include "coder/generator/component/component.h"
#include "coder/opcoders/parallel.h"
namespace mindspore::lite::micro {
void InferenceGenerator::CodeNetExecuteFunc(std::ofstream &ofs) {
ofs << "void Execute(bool train_mode) {\n";
if (config_->support_parallel()) {
ofs << " " << gThreadNum << " = GetCurrentThreadNum();\n";
ofs << " SetSpinCountMaxValue();\n";
}
for (const auto &block : ctx_->code_blocks()) {
ofs << " {\n" << block << " }\n";
}
for (const auto &block : ctx_->after_inference_code_blocks()) {
ofs << block << "\n";
}
if (config_->support_parallel()) {
ofs << " SetSpinCountMinValue();\n";
}
ofs << "}\n";
}
int InferenceGenerator::CodeNetHFile() {
std::string net_include_file = net_src_file_path_ + net_inc_hfile_;
std::ofstream ofs(net_include_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_include_file;
ofs << g_hwLicense;
ofs << kExternCpp;
CodeInputState(ofs);
CodeCommonNetH(ofs);
CodeCopyOutputsState(ofs);
if (is_get_quant_args_) {
CodeGraphQuantArgsState(ofs);
}
CodeManageResourceState(ofs);
CodeInferenceState(ofs);
ofs << kEndExternCpp;
ofs.close();
return RET_OK;
}
@ -46,24 +57,9 @@ int InferenceGenerator::CodeNetCFile() {
std::ofstream ofs(net_impl_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_impl_file;
ofs << g_hwLicense << "\n"
<< "#include \"" << net_weight_hfile_ << "\"\n"
<< "#include \"" << net_inc_hfile_ << "\"\n\n";
if (config_->support_parallel()) {
ofs << "#include \"" << kThreadWrapper << "\"\n\n";
}
if (config_->debug_mode()) {
ofs << "#include \"" << kDebugUtils << "\"\n";
}
CodeGlobalCodeBlocks(ofs, ctx_);
CodeInputImplement(ofs, ctx_);
CodeCommonNetC(ofs);
CodeCopyOutputsImplement(ofs, ctx_);
CodeInitResourceImplement(ofs, ctx_);
CodeFreeResourceImplement(ofs, ctx_, *config_);
if (is_get_quant_args_) {
CodeGraphQuantArgsImplement(ofs, ctx_);
}
CodeNetRunFunc(ofs);
CodeNetExecuteFunc(ofs);
ofs.close();
return RET_OK;
}

View File

@ -28,6 +28,7 @@ class InferenceGenerator : public Generator {
~InferenceGenerator() override = default;
private:
void CodeNetExecuteFunc(std::ofstream &ofs) override;
int CodeNetHFile() override;
int CodeNetCFile() override;
};

View File

@ -1,73 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/generator/train/train_generator.h"
#include <vector>
#include <string>
#include "coder/generator/component/common_component.h"
#include "coder/generator/component/weight_component.h"
#include "coder/generator/component/train_component.h"
#include "coder/generator/component/const_blocks/license.h"
namespace mindspore::lite::micro {
void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
ofs << "float ComputeLossAndGradient() {\n";
ofs << " float loss = 0;\n";
for (const auto &block : ctx_->train_blocks()) {
ofs << "\t{\n" << block << "\t}\n";
}
ofs << " return loss;\n";
ofs << "}\n";
}
int TrainGenerator::CodeNetHFile() {
std::string net_include_file = net_src_file_path_ + net_inc_hfile_;
std::ofstream ofs(net_include_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_include_file;
ofs << g_hwLicense;
if (config_->code_mode() == CodeMode::Inference) {
ofs << "#include \"src/runtime/thread_pool.h\"\n";
}
ofs << "#include \"microtensor.h\"\n\n";
CodeTrainParams(ofs);
CodeInputState(ofs);
if (config_->target() != kCortex_M) {
CodeInitWeightState(ofs);
}
CodeManageResourceState(ofs);
CodeInferenceState(ofs);
CodeFeaturesState(ofs);
CodeTrainState(ofs);
return RET_OK;
}
int TrainGenerator::CodeNetCFile() {
std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
std::ofstream ofs(net_impl_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_impl_file;
CodeInputImplement(ofs, ctx_);
CodeInitResourceImplement(ofs, ctx_);
CodeFreeResourceImplement(ofs, ctx_, *config_);
CodeFeaturesImplement(ofs, ctx_);
CodeNetRunFunc(ofs);
CodeGradientFunc(ofs);
CodeTrainImplement(ofs, ctx_);
ofs.close();
return RET_OK;
}
} // namespace mindspore::lite::micro

View File

@ -24,6 +24,7 @@
#include "coder/log.h"
#include "coder/opcoders/op_coder_register.h"
#include "coder/utils/type_cast.h"
#include "coder/utils/train_utils.h"
#include "schema/inner/model_generated.h"
#include "securec/include/securec.h"
#include "src/common/prim_util.h"
@ -183,7 +184,8 @@ int CoderGraph::InitGraphInOutTensors() {
}
}
SetOutputIndices(output_indices);
InitInputs();
int ret = InitInputs();
MS_CHECK_RET_CODE(ret, "init graph input tensors failed.");
InitOutputs();
return RET_OK;
}
@ -192,23 +194,90 @@ std::vector<lite::Tensor *> CoderGraph::input_tensors() const { return input_ten
std::vector<lite::Tensor *> CoderGraph::output_tensors() const { return output_tensors_; }
void CoderGraph::InitInputs() {
for (const auto &pair : inputs_map_) {
std::vector<Tensor *> tensors = pair.second;
input_tensors_.insert(input_tensors_.end(), tensors.begin(), tensors.end());
}
// remove duplicate tensors
std::set<lite::Tensor *> unique;
unique.insert(input_tensors_.begin(), input_tensors_.end());
std::vector<lite::Tensor *> CoderGraph::eval_output_tensors() const { return eval_output_tensors_; }
std::vector<lite::Tensor *> CoderGraph::train_output_tensors() const { return train_output_tensors_; }
int CoderGraph::InitInputs() {
input_tensors_.clear();
input_tensors_.insert(input_tensors_.end(), unique.begin(), unique.end());
auto graph_in_size = model_->graph_.input_indices_.size();
for (size_t i = 0; i < graph_in_size; i++) {
auto in_tensor_idx = model_->graph_.input_indices_[i];
MS_CHECK_TRUE_MSG(in_tensor_idx < all_tensors_.size(), RET_ERROR, "in tensor idx is out of range.");
auto in_tensor = all_tensors_.at(in_tensor_idx);
MS_CHECK_TRUE_MSG(in_tensor != nullptr, RET_ERROR, "in_tensor is nullptr.");
input_tensors_.emplace_back(in_tensor);
}
return RET_OK;
}
void CoderGraph::InitOutputs() {
std::transform(output_indices_.begin(), output_indices_.end(), std::back_inserter(output_tensors_),
output_tensors_.clear();
(void)std::transform(output_indices_.begin(), output_indices_.end(), std::back_inserter(output_tensors_),
[&](uint32_t a) { return this->all_tensors_.at(a); });
}
int CoderGraph::CompileTrainOutputs(const std::vector<OperatorCoder *> &train_coders) {
train_outputs_map_.clear();
train_output_tensors_.clear();
for (auto train_coder : train_coders) {
MS_CHECK_TRUE_MSG(train_coder != nullptr, RET_ERROR, "train coder is nullptr.");
if (outputs_map_.find(train_coder->name()) == outputs_map_.end() || IsMaskOutput(train_coder) ||
train_outputs_map_.find(train_coder->name()) != train_outputs_map_.end()) { // filter optimizer out tensors out
continue;
}
MS_CHECK_TRUE_MSG(!train_coder->output_tensors().empty(), RET_ERROR, "output tensors is empty.");
auto ms_tensor = train_coder->output_tensors().at(0);
if (ms_tensor != nullptr) {
train_outputs_map_[train_coder->name()].emplace_back(ms_tensor);
train_output_tensors_.emplace_back(ms_tensor);
}
}
if (train_outputs_map_.empty()) {
train_outputs_map_ = outputs_map_;
}
if (train_output_tensors_.empty()) {
train_output_tensors_ = output_tensors_;
}
return RET_OK;
}
int CoderGraph::CompileEvalOutputs(const std::vector<OperatorCoder *> &train_coders) {
eval_outputs_map_.clear();
eval_output_tensors_.clear();
for (auto coder : train_coders) {
MS_CHECK_TRUE_MSG(coder != nullptr, RET_ERROR, "coder is nullptr.");
if (!IsLossCoder(coder) || IsGradCoder(coder)) {
continue;
}
for (auto in_coder : coder->input_ops()) {
if (IsLossCoder(in_coder) || IsGradCoder(in_coder)) {
continue;
}
auto in_in_coders = in_coder->input_ops();
bool is_loss = std::any_of(in_in_coders.begin(), in_in_coders.end(),
[](const OperatorCoder *coder) { return IsLossCoder(coder); });
if (is_loss || eval_outputs_map_.find(in_coder->name()) != eval_outputs_map_.end()) {
continue;
}
MS_CHECK_TRUE_MSG(!in_coder->output_tensors().empty(), RET_ERROR, "output tensors is empty.");
auto ms_tensor = in_coder->output_tensors().at(0);
if (ms_tensor != nullptr) {
ms_tensor->set_init_ref_count(ms_tensor->init_ref_count() + 1);
eval_outputs_map_[in_coder->name()].emplace_back(ms_tensor);
eval_output_tensors_.emplace_back(ms_tensor);
}
}
}
if (eval_outputs_map_.empty()) {
eval_outputs_map_ = outputs_map_;
}
if (eval_output_tensors_.empty()) {
eval_output_tensors_ = output_tensors_;
}
return RET_OK;
}
void CoderGraph::SetAllTensors(const std::vector<Tensor *> &all_tensors) {
all_tensors_.insert(all_tensors_.end(), all_tensors.begin(), all_tensors.end());
}
@ -241,6 +310,8 @@ std::vector<lite::Tensor *> CoderGraph::all_tensors() const { return this->all_t
const std::map<std::string, std::vector<lite::Tensor *>> &CoderGraph::GetOutputsMap() const { return outputs_map_; }
const std::map<std::string, std::vector<Tensor *>> &CoderGraph::GetEvalOutputsMap() const { return eval_outputs_map_; }
std::vector<uint32_t> CoderGraph::input_indices() const { return this->input_indices_; }
std::vector<uint32_t> CoderGraph::output_indices() const { return this->output_indices_; }

View File

@ -22,6 +22,7 @@
#include <unordered_map>
#include <vector>
#include <string>
#include "tools/converter/micro/coder/opcoders/op_coder.h"
#include "tools/converter/micro/coder/config.h"
#include "include/context.h"
#include "include/model.h"
@ -40,9 +41,13 @@ class CoderGraph {
void SetAllTensors(const std::vector<Tensor *> &all_tensors);
void InitInputs();
int InitInputs();
void InitOutputs();
int CompileTrainOutputs(const std::vector<OperatorCoder *> &train_coders);
int CompileEvalOutputs(const std::vector<OperatorCoder *> &train_coders);
void SetInputIndices(const std::vector<uint32_t> &input_indices);
void SetOutputIndices(const std::vector<uint32_t> &output_indices);
@ -59,10 +64,14 @@ class CoderGraph {
std::vector<Tensor *> output_tensors() const;
std::vector<Tensor *> eval_output_tensors() const;
std::vector<Tensor *> train_output_tensors() const;
std::vector<Tensor *> all_tensors() const;
const std::map<NODE_ID, std::vector<Tensor *>> &GetOutputsMap() const;
const std::map<std::string, std::vector<Tensor *>> &GetEvalOutputsMap() const;
const Model *model() const { return this->model_; }
void DumpUnSupportLayer(Target target);
@ -72,9 +81,13 @@ class CoderGraph {
// others are parameter_node
std::vector<Tensor *> all_tensors_;
std::vector<Tensor *> input_tensors_;
std::vector<Tensor *> input_tensors_; // graph origin inputs
std::vector<Tensor *> output_tensors_;
std::vector<Tensor *> output_tensors_; // graph origin outputs
std::vector<Tensor *> eval_output_tensors_; // graph outputs in Eval mode
std::vector<Tensor *> train_output_tensors_; // graph outputs in Train mode
std::vector<uint32_t> input_indices_;
@ -82,7 +95,12 @@ class CoderGraph {
std::map<std::string, std::vector<Tensor *>> inputs_map_;
std::map<std::string, std::vector<Tensor *>> outputs_map_;
std::map<std::string, std::vector<Tensor *>> outputs_map_; // graph origin outputs tensor map
// <node name, graph output tensors>
std::map<std::string, std::vector<Tensor *>> eval_outputs_map_; // graph eval outputs tensor map
std::map<std::string, std::vector<Tensor *>> train_outputs_map_; // graph train outputs tensor map
Model *model_{nullptr};
};

View File

@ -21,10 +21,10 @@
#include "nnacl/int8/quantize.h"
#include "coder/log.h"
#include "src/litert/tensor_category.h"
#include "src/common/quant_utils.h"
namespace mindspore::lite::micro {
namespace {
constexpr int kRoundUp = 2;
constexpr int kPerTensor = 1;
} // namespace
Conv2DBaseCoder::~Conv2DBaseCoder() {
FreeConvQuantParams();

View File

@ -52,7 +52,8 @@ int MatMulFP32BaseCoder::InitBiasData() {
is_bias_broadcast_ = true;
}
ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * sizeof(float);
bias_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
bias_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight,
bias_tensor_->tensor_name() + "_online_pack"));
MS_CHECK_PTR(bias_ptr_);
}
return RET_OK;
@ -81,7 +82,8 @@ int MatMulFP32BaseCoder::InitBufferA() {
}
a_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * sizeof(float));
if (params_->a_const_) {
a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight,
input_tensors_.at(0)->tensor_name() + "_online_pack"));
} else {
a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, a_pack_ptr_size_, kWorkspace));
}
@ -95,7 +97,8 @@ int MatMulFP32BaseCoder::InitBufferB() {
}
b_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->col_align_ * params_->deep_ * sizeof(float));
if (params_->b_const_) {
b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight,
input_tensors_.at(1)->tensor_name() + "_online_pack"));
} else {
b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, b_pack_ptr_size_, kWorkspace));
}
@ -244,7 +247,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
if (vec_matmul_) {
code << " const float *batch_a_ptr = " << a_pack_str << " + i * " << params_->deep_ << ";\n";
code << " const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_ << ";\n";
code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n";
code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n ";
code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_,
params_->deep_, cur_oc);
@ -253,7 +256,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
<< ";\n";
code << " const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_align_
<< ";\n";
code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n";
code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n ";
code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_,
params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc");

View File

@ -21,7 +21,6 @@
#include <string>
#include <memory>
#include "tools/converter/micro/coder/context.h"
#include "tools/converter/micro/coder/graph.h"
#include "tools/converter/micro/coder/allocator/allocator.h"
#include "include/errorcode.h"
#include "src/litert/kernel_exec.h"
@ -58,6 +57,8 @@ class OperatorCoder {
const std::vector<Tensor *> input_tensors() const;
const std::vector<Tensor *> output_tensors() const;
void SetInputOps(const std::vector<OperatorCoder *> &input_ops) { input_ops_ = input_ops; }
void SetOutputOps(const std::vector<OperatorCoder *> &output_ops) { output_ops_ = output_ops; }
void AddInputOp(OperatorCoder *op) { input_ops_.push_back(op); }
void AddOutputOp(OperatorCoder *op) { output_ops_.push_back(op); }
const std::vector<OperatorCoder *> input_ops() const { return input_ops_; }

View File

@ -19,19 +19,14 @@
#include <vector>
#include <utility>
#include "coder/context.h"
#include "coder/train.h"
#include "coder/allocator/allocator.h"
#include "coder/generator/generator.h"
#include "coder/generator/inference/inference_generator.h"
#include "coder/generator/train/train_generator.h"
#include "coder/opcoders/op_coder_builder.h"
#include "coder/opcoders/kernel_registry.h"
#include "coder/utils/coder_utils.h"
#include "coder/log.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/common/version_manager.h"
#include "src/litert/infer_manager.h"
#include "src/litert/scheduler.h"
#include "src/litert/lite_model.h"
#include "include/errorcode.h"
#include "include/model.h"
@ -42,8 +37,7 @@
namespace mindspore::lite::micro {
CoderSession::CoderSession() { allocator_ = MemoryAllocator::GetInstance(); }
int CoderSession::EndCode() {
int ret = RET_OK;
int CoderSession::PassArgsToContext() {
context_->set_tensor_map(allocator_->tensors_map());
context_->set_saved_weights(allocator_->saved_weights());
size_t de_quant_max_workspace_size = nnacl::Dequant::GetInstance()->de_quant_max_workspace();
@ -53,26 +47,21 @@ int CoderSession::EndCode() {
context_->set_total_buffer_size(final_total_size);
context_->set_graph_inputs(coder_graph_->input_tensors());
context_->set_graph_outputs(coder_graph_->output_tensors());
Configurator *config = Configurator::GetInstance();
if (config->debug_mode()) {
if (Configurator::GetInstance()->debug_mode()) {
std::vector<std::string> blocks;
blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_);
context_->set_code_blocks(blocks);
}
if (config->code_mode() == Train) {
ret = Train::TransformGraphForTrain(context_.get(), op_coders_, schema_version_);
MS_CHECK_RET_CODE(ret, "transform graph for train failed.");
}
return ret;
return RET_OK;
}
int CoderSession::Run() {
MS_LOG(INFO) << "start run opcoders";
// 1. assign memory
int CoderSession::Preprocess() {
// assign memory
std::vector<lite::Tensor *> inputs = coder_graph_->input_tensors();
int ret = allocator_->Assign(inputs, op_coders_);
MS_CHECK_RET_CODE(ret, "assign memory failed");
// 2. prepare, init model parameters
// prepare, init model parameters
for (const auto &op_coder : op_coders_) {
MS_CHECK_PTR(op_coder);
MS_LOG(DEBUG) << "prepare: " << op_coder->name();
@ -80,41 +69,40 @@ int CoderSession::Run() {
MS_CHECK_RET_CODE(ret, "prepare coder " << op_coder->name() << " failed");
allocator_->enable_is_next();
}
// 3. docode, write operator code
return RET_OK;
}
int CoderSession::DoCode() {
int ret = RET_OK;
for (const auto &op_coder : op_coders_) {
MS_CHECK_PTR(op_coder);
MS_LOG(DEBUG) << "code: " << op_coder->name();
ret = op_coder->DoCode(this->context_.get());
MS_CHECK_RET_CODE(ret, "do coder " << op_coder->name() << " failed");
}
return ret;
}
int CoderSession::Run() {
MS_LOG(INFO) << "start run opcoders";
ret = this->EndCode();
MS_CHECK_RET_CODE(ret, "End code failed.");
int ret = Preprocess();
MS_CHECK_RET_CODE(ret, "preprocess failed");
ret = DoCode();
MS_CHECK_RET_CODE(ret, "do code failed");
(void)PassArgsToContext();
MS_LOG(INFO) << "run opcoders success";
return RET_OK;
}
int CoderSession::GenerateCode() {
MS_LOG(INFO) << "CoderSession::GenerateCode start";
std::shared_ptr<Generator> generator;
Configurator *config = Configurator::GetInstance();
CodeMode code_mode = config->code_mode();
switch (code_mode) {
case Inference:
MS_LOG(INFO) << "generate code for Inference";
generator = std::make_shared<InferenceGenerator>(std::move(context_));
break;
case Train:
MS_LOG(INFO) << "generate code for Train";
generator = std::make_shared<TrainGenerator>(std::move(context_));
break;
default:
MS_LOG(ERROR) << "unsupported generator code mode, " << code_mode;
return RET_ERROR;
}
auto generator = std::make_shared<InferenceGenerator>(std::move(context_));
MS_CHECK_PTR(generator);
// when use file, coder context need to remove initial parameters from tensors info
// we use tmp_tensor_list to storage
MS_CHECK_PTR(generator);
int ret = generator->GenerateCode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "generate code failed";
@ -146,29 +134,37 @@ int CoderSession::Build() {
}
int CoderSession::InitOpcodersInputsAndOutputs() {
std::map<Tensor *, OperatorCoder *> input_node_map;
std::map<Tensor *, OperatorCoder *> output_node_map;
std::map<Tensor *, OperatorCoder *> tensor_pre_coders; // a tensor is a certain coder's output
std::map<Tensor *, std::vector<OperatorCoder *>> tensor_post_coders; // a tensor is many coder's input
for (const auto &op_coder : op_coders_) {
std::vector<Tensor *> inputs = op_coder->input_tensors();
std::for_each(inputs.begin(), inputs.end(),
[&](Tensor *t) { input_node_map.insert(std::make_pair(t, op_coder.get())); });
std::vector<Tensor *> outputs = op_coder->input_tensors();
std::for_each(outputs.begin(), outputs.end(),
[&](Tensor *t) { output_node_map.insert(std::make_pair(t, op_coder.get())); });
for (auto *in_tensor : op_coder->input_tensors()) {
tensor_post_coders[in_tensor].emplace_back(op_coder.get());
}
for (auto *output_tensor : op_coder->output_tensors()) {
tensor_pre_coders[output_tensor] = op_coder.get();
}
}
for (const auto &op_coder : op_coders_) {
op_coder->SetInputOps({});
std::vector<Tensor *> inputs = op_coder->input_tensors();
for (const auto &tensor : inputs) {
auto item = output_node_map.find(tensor);
if (item != output_node_map.end()) {
auto item = tensor_pre_coders.find(tensor);
if (item != tensor_pre_coders.end() && item->second != op_coder.get()) {
op_coder->AddInputOp(item->second);
}
}
op_coder->SetOutputOps({});
std::vector<Tensor *> outputs = op_coder->output_tensors();
for (const auto &tensor : outputs) {
auto item = input_node_map.find(tensor);
if (item != input_node_map.end()) {
op_coder->AddOutputOp(item->second);
auto item = tensor_post_coders.find(tensor);
if (item != tensor_post_coders.end()) {
for (auto *find_coder : item->second) {
if (find_coder == op_coder.get()) {
continue;
}
op_coder->AddOutputOp(find_coder);
}
}
}
}
@ -290,7 +286,7 @@ int CoderSession::CreateOpCoders() {
op_coders_.push_back(std::move(op_coder));
builder.Reset();
}
InitOpcodersInputsAndOutputs();
(void)InitOpcodersInputsAndOutputs();
return RET_OK;
}
@ -307,11 +303,5 @@ int CoderSession::CompileGraph() {
MS_CHECK_RET_CODE(InitTensorsRef(), "InitTensorsRefcount failed!");
return RET_OK;
}
std::shared_ptr<CoderSession> CreateCoderSession() {
auto session = std::make_shared<CoderSession>();
return session;
}
CoderSession::~CoderSession() { allocator_->Free(); }
} // namespace mindspore::lite::micro

View File

@ -36,11 +36,16 @@ class CoderSession {
int Init(const void *content, int size);
int Build();
virtual int Build();
int Run();
virtual int Run();
int GenerateCode();
virtual int GenerateCode();
protected:
int Preprocess();
virtual int DoCode();
virtual int PassArgsToContext();
private:
OpParameter *GenParameterAndInfer(const LiteGraph::Node *node, const std::vector<lite::Tensor *> &inputs,
@ -50,15 +55,15 @@ class CoderSession {
int CreateOpCoders();
int InitCodeGraph();
int CompileGraph();
int EndCode();
protected:
std::vector<std::unique_ptr<OperatorCoder>> op_coders_;
std::unique_ptr<CoderGraph> coder_graph_{nullptr};
std::unique_ptr<CoderContext> context_{nullptr};
MemoryAllocator *allocator_{nullptr};
std::vector<std::unique_ptr<OperatorCoder>> op_coders_;
private:
int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
};
std::shared_ptr<CoderSession> CreateCoderSession();
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SESSION_H_

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/train/train_generator.h"
#include <string>
#include "coder/generator/component/train_component.h"
#include "coder/opcoders/parallel.h"
#include "coder/generator/component/component.h"
#include "tools/common/string_util.h"
namespace mindspore::lite::micro {
void TrainGenerator::CodeTrainAndEvalFunc(std::ofstream &ofs) {
size_t i = 0;
size_t code_blocks_size = code_blocks_with_flag_.size();
while (i < code_blocks_size) {
bool is_train_only = code_blocks_with_flag_.at(i).second;
if (!is_train_only) {
ofs << " {\n" << code_blocks_with_flag_.at(i).first << " }\n";
i++;
continue;
}
size_t j = i;
while (j < code_blocks_size && code_blocks_with_flag_.at(j).second) { // is loss or grad op
j++;
}
ofs << " if (train_mode) {\n";
for (; i < j; i++) {
auto code_block = code_blocks_with_flag_.at(i).first;
(void)FindAndReplaceAll(&code_block, " ", " ");
ofs << " {\n" << code_block << " }\n";
}
ofs << " }\n";
}
}
void TrainGenerator::CodeNetExecuteFunc(std::ofstream &ofs) {
ofs << "void Execute(bool train_mode) {\n";
if (config_->support_parallel()) {
ofs << " " << gThreadNum << " = GetCurrentThreadNum();\n";
ofs << " SetSpinCountMaxValue();\n";
}
CodeTrainAndEvalFunc(ofs);
if (config_->support_parallel()) {
ofs << " SetSpinCountMinValue();\n";
}
ofs << "}\n";
}
int TrainGenerator::CodeNetHFile() {
std::string net_include_file = net_src_file_path_ + net_inc_hfile_;
std::ofstream ofs(net_include_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_include_file;
CodeCommonNetH(ofs);
CodeCopyTrainOutputsState(ofs);
ofs << kEndExternCpp;
ofs.close();
return RET_OK;
}
int TrainGenerator::CodeNetCFile() {
std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
std::ofstream ofs(net_impl_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_impl_file;
CodeCommonNetC(ofs);
CodeCopyTrainOutputsImplement(ofs, ctx_);
CodeNetExecuteFunc(ofs);
ofs.close();
return RET_OK;
}
} // namespace mindspore::lite::micro

View File

@ -14,23 +14,30 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_TRAIN_TRAIN_GENERATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_TRAIN_TRAIN_GENERATOR_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_TRAIN_TRAIN_GENERATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_TRAIN_TRAIN_GENERATOR_H_
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/micro/coder/generator/generator.h"
namespace mindspore::lite::micro {
class TrainGenerator : public Generator {
public:
explicit TrainGenerator(std::unique_ptr<CoderContext> ctx) : Generator(std::move(ctx)) {}
TrainGenerator(std::unique_ptr<CoderContext> ctx, std::vector<std::pair<std::string, bool>> code_blocks_with_flag)
: Generator(std::move(ctx)), code_blocks_with_flag_(std::move(code_blocks_with_flag)) {}
~TrainGenerator() override = default;
private:
void CodeTrainAndEvalFunc(std::ofstream &ofs);
void CodeNetExecuteFunc(std::ofstream &ofs) override;
int CodeNetHFile() override;
int CodeNetCFile() override;
void CodeGradientFunc(std::ofstream &ofs) const;
private:
std::vector<std::pair<std::string, bool>> code_blocks_with_flag_; // <code block, is op only in train mode>
};
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_TRAIN_TRAIN_GENERATOR_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_TRAIN_TRAIN_GENERATOR_H_

View File

@ -0,0 +1,142 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/micro/coder/train/train_session.h"
#include <map>
#include <utility>
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "include/errorcode.h"
#include "tools/converter/micro/coder/utils/train_utils.h"
#include "tools/converter/micro/coder/train/train_generator.h"
namespace mindspore::lite::micro {
int CoderTrainSession::Build() {
int ret = CoderSession::Build();
MS_CHECK_RET_CODE(ret, "code session build failed.");
MS_CHECK_RET_CODE(CompileTrainCoders(), "CompileTrainCoders failed");
MS_CHECK_RET_CODE(coder_graph_->CompileTrainOutputs(train_op_coders_), "CompileTrainOutputs failed!");
MS_CHECK_RET_CODE(coder_graph_->CompileEvalOutputs(train_op_coders_), "CompileEvalOutputs failed!");
MS_CHECK_RET_CODE(CompileEvalCoders(coder_graph_->GetEvalOutputsMap()), "CompileTrainCoders failed.");
return RET_OK;
}
int CoderTrainSession::Run() {
MS_LOG(INFO) << "start run op coders";
int ret = Preprocess();
MS_CHECK_RET_CODE(ret, "preprocess failed");
ret = DoCode();
MS_CHECK_RET_CODE(ret, "do code failed");
PassArgsToContext();
MS_LOG(INFO) << "run op coders success";
return RET_OK;
}
int CoderTrainSession::GenerateCode() {
MS_LOG(INFO) << "CoderSession::GenerateCode start";
auto generator = std::make_shared<TrainGenerator>(std::move(context_), code_blocks_with_flag_);
MS_CHECK_PTR(generator);
int ret = generator->GenerateCode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "generate code failed";
}
MS_LOG(INFO) << "CoderSession::GenerateCode done";
return ret;
}
int CoderTrainSession::DoCode() {
int ret = RET_OK;
size_t last_idx = context_->code_blocks().size();
for (const auto &op_coder : op_coders_) {
MS_CHECK_PTR(op_coder);
MS_LOG(DEBUG) << "code: " << op_coder->name();
ret = op_coder->DoCode(this->context_.get());
MS_CHECK_RET_CODE(ret, "do coder " << op_coder->name() << " failed");
auto code_blocks = context_->code_blocks();
auto cur_indx = code_blocks.size();
MS_CHECK_TRUE_MSG(cur_indx > last_idx, RET_ERROR, "append code failed.");
bool is_train_only =
std::find(eval_op_coders_.begin(), eval_op_coders_.end(), op_coder.get()) == eval_op_coders_.end();
for (; last_idx < cur_indx; last_idx++) {
code_blocks_with_flag_.emplace_back(code_blocks.at(last_idx), is_train_only);
}
}
return ret;
}
int CoderTrainSession::UpdateCodeBlocksWithFlag() {
auto code_blocks = context_->code_blocks();
MS_CHECK_TRUE_MSG(code_blocks.size() == code_blocks_with_flag_.size(), RET_ERROR, "code blocks size is unmatched.");
for (size_t i = 0; i < code_blocks.size(); i++) {
code_blocks_with_flag_.at(i).first = code_blocks.at(i);
}
return RET_OK;
}
int CoderTrainSession::PassArgsToContext() {
int ret = RET_OK;
(void)CoderSession::PassArgsToContext();
if (Configurator::GetInstance()->debug_mode()) {
ret = UpdateCodeBlocksWithFlag();
MS_CHECK_RET_CODE(ret, "update code_blocks_with_flag_ failed.");
}
context_->set_graph_train_outputs(coder_graph_->train_output_tensors());
context_->set_graph_eval_outputs(coder_graph_->eval_output_tensors());
return ret;
}
void CoderTrainSession::FindEvalCoders(OperatorCoder *coder) {
if (coder == nullptr) {
return;
}
if (std::find(eval_op_coders_.begin(), eval_op_coders_.end(), coder) ==
eval_op_coders_.end()) { // kernel is not already in vector
for (auto in_coder : coder->input_ops()) {
FindEvalCoders(in_coder);
}
if (!IsLossCoder(coder)) {
eval_op_coders_.emplace_back(coder);
}
}
}
int CoderTrainSession::CompileTrainCoders() {
train_op_coders_.clear();
(void)std::transform(op_coders_.begin(), op_coders_.end(), std::back_inserter(train_op_coders_),
[](const std::unique_ptr<OperatorCoder> &coder) { return coder.get(); });
return RET_OK;
}
int CoderTrainSession::CompileEvalCoders(const std::map<std::string, std::vector<Tensor *>> &eval_outputs_map) {
eval_op_coders_.clear();
for (const auto &item : eval_outputs_map) {
std::string kernel_name = item.first;
auto iter = std::find_if(train_op_coders_.begin(), train_op_coders_.end(),
[&kernel_name](const OperatorCoder *coder) { return (coder->name() == kernel_name); });
MS_CHECK_TRUE_MSG(iter != train_op_coders_.end(), RET_ERROR, "can't find output coder in Eval mode.");
MS_CHECK_TRUE_MSG(*iter != nullptr, RET_ERROR, "find output coder in Eval mode.");
(void)FindEvalCoders(*iter);
}
if (eval_op_coders_.empty()) {
eval_op_coders_ = train_op_coders_;
}
return RET_OK;
}
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,48 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_TRAIN_TRAIN_SESSION_H_
#include <map>
#include <utility>
#include <string>
#include <vector>
#include "tools/converter/micro/coder/session.h"
namespace mindspore::lite::micro {
class CoderTrainSession : public CoderSession {
public:
int Build() override;
int Run() override;
int GenerateCode() override;
private:
int DoCode() override;
int UpdateCodeBlocksWithFlag();
int PassArgsToContext() override;
void FindEvalCoders(OperatorCoder *coder);
int CompileTrainCoders();
int CompileEvalCoders(const std::map<std::string, std::vector<Tensor *>> &eval_outputs_map);
private:
std::vector<std::pair<std::string, bool>> code_blocks_with_flag_; // <code block, is op only in train mode>
std::vector<OperatorCoder *> train_op_coders_;
std::vector<OperatorCoder *> eval_op_coders_;
};
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_TRAIN_TRAIN_SESSION_H_

View File

@ -14,11 +14,11 @@
* limitations under the License.
*/
#include "coder/utils/coder_utils.h"
#include <set>
#include <queue>
#include <string>
#include <memory>
#include <fstream>
#include "src/common/prim_util.h"
#include "tools/converter/micro/coder/log.h"
#include "tools/converter/micro/coder/utils/type_cast.h"
#include "tools/converter/micro/coder/allocator/allocator.h"

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/train_utils.h"
#include <string>
#include <vector>
#include "src/common/prim_util.h"
namespace mindspore::lite::micro {
namespace {
constexpr char kGradName[] = "Gradients";
}
bool IsLossCoder(const OperatorCoder *coder) {
MS_CHECK_TRUE_MSG(coder != nullptr, false, "coder is nullptr");
bool is_loss = false;
const std::vector<std::string> loss_names = {"loss_fct", "_loss_fn", "SigmoidCrossEntropy"};
for (auto &name : loss_names) {
if (coder->name().find(name) != std::string::npos) {
is_loss = true;
break;
}
}
return is_loss;
}
bool IsGradCoder(const OperatorCoder *coder) {
MS_CHECK_TRUE_MSG(coder != nullptr, false, "coder is nullptr");
return coder->name().find(kGradName) != std::string::npos;
}
bool IsOptimizer(const OperatorCoder *coder) {
MS_CHECK_TRUE_MSG(coder != nullptr, false, "coder is nullptr");
auto node = coder->node();
MS_CHECK_TRUE_MSG(node != nullptr, false, "coder's node is nullptr");
auto node_type = static_cast<schema::PrimitiveType>(GetPrimitiveType(node->primitive_));
return (node_type == schema::PrimitiveType_Adam) || (node_type == schema::PrimitiveType_SGD) ||
(node_type == schema::PrimitiveType_ApplyMomentum);
}
bool IsMaskOutput(const OperatorCoder *coder) {
MS_CHECK_TRUE_MSG(coder != nullptr, false, "coder is nullptr");
auto node = coder->node();
MS_CHECK_TRUE_MSG(node != nullptr, false, "coder's node is nullptr");
auto node_type = static_cast<schema::PrimitiveType>(GetPrimitiveType(node->primitive_));
return (IsOptimizer(coder) || (node_type == schema::PrimitiveType_Assign));
}
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,28 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_UTILS_TRAIN_UTILS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_UTILS_TRAIN_UTILS_H_
#include "coder/opcoders/op_coder.h"
namespace mindspore::lite::micro {
bool IsLossCoder(const OperatorCoder *code);
bool IsGradCoder(const OperatorCoder *coder);
bool IsOptimizer(const OperatorCoder *coder);
bool IsMaskOutput(const OperatorCoder *coder);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_UTILS_TRAIN_UTILS_H_