diff --git a/mindspore/lite/micro/coder/coder.cc b/mindspore/lite/micro/coder/coder.cc index 1d54bf9cbcf..1317d7b9f12 100644 --- a/mindspore/lite/micro/coder/coder.cc +++ b/mindspore/lite/micro/coder/coder.cc @@ -37,7 +37,6 @@ class CoderFlags : public virtual FlagParser { CoderFlags() { AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", ""); AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", "."); - AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", ""); AddFlag(&CoderFlags::target_, "target", "generated code target, x86| ARM32M| ARM32A| ARM64", "x86"); AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Inference | Train", "Inference"); AddFlag(&CoderFlags::support_parallel_, "supportParallel", "whether support parallel launch, true | false", false); @@ -48,7 +47,6 @@ class CoderFlags : public virtual FlagParser { std::string model_path_; bool support_parallel_{false}; - std::string code_module_name_; std::string code_path_; std::string code_mode_; bool debug_mode_{false}; @@ -84,6 +82,27 @@ int Coder::Run(const std::string &model_path) { return status; } +int Configurator::ParseProjDir(std::string model_path) { + // split model_path to get model file name + proj_dir_ = model_path; + size_t found = proj_dir_.find_last_of("/\\"); + if (found != std::string::npos) { + proj_dir_ = proj_dir_.substr(found + 1); + } + found = proj_dir_.find(".ms"); + if (found != std::string::npos) { + proj_dir_ = proj_dir_.substr(0, found); + } else { + MS_LOG(ERROR) << "model file's name must be end with \".ms\"."; + return RET_ERROR; + } + if (proj_dir_.size() == 0) { + proj_dir_ = "net"; + MS_LOG(WARNING) << "parse model's name failed, use \"net\" instead."; + } + return RET_OK; +} + int Coder::Init(const CoderFlags &flags) const { static const std::map kTargetMap = { {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}}; @@ -91,6 +110,17 @@ int Coder::Init(const CoderFlags &flags) const { Configurator *config = Configurator::GetInstance(); std::vector> parsers; + parsers.emplace_back([&flags, config]() -> bool { + if (!FileExists(flags.model_path_)) { + MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid"; + return false; + } + if (config->ParseProjDir(flags.model_path_) != RET_OK) { + return false; + } + return true; + }); + parsers.emplace_back([&flags, config]() -> bool { auto target_item = kTargetMap.find(flags.target_); MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_); @@ -119,20 +149,6 @@ int Coder::Init(const CoderFlags &flags) const { return true; }); - parsers.emplace_back([&flags, config]() -> bool { - if (!FileExists(flags.model_path_)) { - MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid"; - return false; - } - if (flags.code_module_name_.empty() || isdigit(flags.code_module_name_.at(0))) { - MS_LOG(ERROR) << "code_gen code module name " << flags.code_module_name_ - << " not valid: it must be given and the first char could not be number"; - return false; - } - config->set_module_name(flags.code_module_name_); - return true; - }); - parsers.emplace_back([&flags, config]() -> bool { const std::string slash = std::string(kSlash); if (!flags.code_path_.empty() && !DirExists(flags.code_path_)) { @@ -141,18 +157,18 @@ int Coder::Init(const CoderFlags &flags) const { } config->set_code_path(flags.code_path_); if (flags.code_path_.empty()) { - std::string path = ".." + slash + config->module_name(); + std::string path = ".." + slash + config->proj_dir(); config->set_code_path(path); } else { if (flags.code_path_.substr(flags.code_path_.size() - 1, 1) != slash) { - std::string path = flags.code_path_ + slash + config->module_name(); + std::string path = flags.code_path_ + slash + config->proj_dir(); config->set_code_path(path); } else { - std::string path = flags.code_path_ + config->module_name(); + std::string path = flags.code_path_ + config->proj_dir(); config->set_code_path(path); } } - return InitProjDirs(flags.code_path_, config->module_name()) != RET_ERROR; + return InitProjDirs(flags.code_path_, config->proj_dir()) != RET_ERROR; }); if (!std::all_of(parsers.begin(), parsers.end(), [](auto &parser) -> bool { return parser(); })) { @@ -162,17 +178,15 @@ int Coder::Init(const CoderFlags &flags) const { } return RET_ERROR; } - config->set_module_name(kModelName); - auto print_parameter = [](auto name, auto value) { MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value; }; print_parameter("modelPath", flags.model_path_); + print_parameter("projectName", config->proj_dir()); print_parameter("target", config->target()); print_parameter("codePath", config->code_path()); print_parameter("codeMode", config->code_mode()); - print_parameter("codeModuleName", config->module_name()); print_parameter("debugMode", config->debug_mode()); return RET_OK; diff --git a/mindspore/lite/micro/coder/config.h b/mindspore/lite/micro/coder/config.h index 6b62b8da265..1e68699009f 100644 --- a/mindspore/lite/micro/coder/config.h +++ b/mindspore/lite/micro/coder/config.h @@ -30,9 +30,6 @@ class Configurator { return &configurator; } - void set_module_name(const std::string &module_name) { module_name_ = module_name; } - std::string module_name() const { return module_name_; } - void set_code_path(const std::string &code_path) { code_path_ = code_path; } std::string code_path() const { return code_path_; } @@ -48,16 +45,18 @@ class Configurator { void set_support_parallel(bool parallel) { support_parallel_ = parallel; } bool support_parallel() const { return support_parallel_; } + int ParseProjDir(std::string model_path); + std::string proj_dir() const { return proj_dir_; } + private: Configurator() = default; ~Configurator() = default; - - std::string module_name_; std::string code_path_; Target target_{kTargetUnknown}; CodeMode code_mode_{Code_Unknown}; bool support_parallel_{false}; bool debug_mode_{false}; + std::string proj_dir_; }; } // namespace mindspore::lite::micro diff --git a/mindspore/lite/micro/coder/context.cc b/mindspore/lite/micro/coder/context.cc index d2ea678f5d1..e9a0b8b1069 100644 --- a/mindspore/lite/micro/coder/context.cc +++ b/mindspore/lite/micro/coder/context.cc @@ -15,17 +15,14 @@ */ #include "coder/context.h" -#include "coder/config.h" #include "coder/allocator/allocator.h" namespace mindspore::lite::micro { CoderContext::CoderContext() { - Configurator *config = Configurator::GetInstance(); - std::string module_name = config->module_name(); - this->input_name_ = module_name + "_I"; - this->output_name_ = module_name + "_O"; - this->buffer_name_ = module_name + "_B"; - this->weight_name_ = module_name + "_W"; + this->input_name_ = "g_Input"; + this->output_name_ = "g_Output"; + this->buffer_name_ = "g_Buffer"; + this->weight_name_ = "g_Weight"; } void CoderContext::AppendCode(const std::string &codeBlock) { this->code_blocks_.emplace_back(codeBlock); } diff --git a/mindspore/lite/micro/coder/generator/component/cmake_component.cc b/mindspore/lite/micro/coder/generator/component/cmake_component.cc index 4f37eff46ab..c43e87900d7 100644 --- a/mindspore/lite/micro/coder/generator/component/cmake_component.cc +++ b/mindspore/lite/micro/coder/generator/component/cmake_component.cc @@ -31,7 +31,7 @@ void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr for (const std::string &c_file : ctx->c_files()) { ofs << " " << c_file << ".o\n"; } - ofs << " net_weight.c.o\n" + ofs << " weight.c.o\n" << " net.c.o\n" << " session.cc.o\n" << " tensor.cc.o\n"; diff --git a/mindspore/lite/micro/coder/generator/component/common_component.cc b/mindspore/lite/micro/coder/generator/component/common_component.cc index 222e7fc27bb..fa46f32d85d 100644 --- a/mindspore/lite/micro/coder/generator/component/common_component.cc +++ b/mindspore/lite/micro/coder/generator/component/common_component.cc @@ -73,29 +73,30 @@ void CodeCopyOutputsImplement(std::ofstream &ofs, const std::unique_ptrSize() << ");\n"; } - ofs << " outputs[0] = net_B;\n" - " return RET_OK;\n" + ofs << " return RET_OK;\n" "}\n\n"; } -void CodeInputState(std::ofstream &ofs, const std::string &module_name) { +void CodeInputState(std::ofstream &ofs) { ofs << "/**\n" << " * set input tensors\n" << " * @param inputs, the input data ptr's array of the model, the tensors' count of input may be greater than " "one.\n" << " * @param num, the input data's number of the model.\n" << " **/\n" - << "int " << module_name << "_SetInputs(const void **inputs, int num);\n\n"; + << "int " + << "SetInputs(const void **inputs, int num);\n\n"; } -void CodeInputImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr &ctx) { +void CodeInputImplement(std::ofstream &ofs, const std::unique_ptr &ctx) { // input tensors std::vector inputs = ctx->graph_inputs(); for (size_t i = 0; i < inputs.size(); ++i) { ofs << "static const unsigned char *" << ctx->input_name() + std::to_string(i) << " = 0;\n"; } size_t size = inputs.size(); - ofs << "int " << module_name << "_SetInputs(const void **inputs, int num) {\n" + ofs << "int " + << "SetInputs(const void **inputs, int num) {\n" << " if (inputs == NULL) {\n" " return RET_ERROR;\n" " }\n" @@ -108,15 +109,15 @@ void CodeInputImplement(std::ofstream &ofs, const std::string &module_name, cons ofs << " return RET_OK;\n}\n"; } -void CodeGraphQuantArgsState(std::ofstream &ofs, const std::string &module_name) { +void CodeGraphQuantArgsState(std::ofstream &ofs) { ofs << "/**\n" << " * get input and output QuantArgs of the model \n" << " **/\n" - << "GraphQuantArgs " << module_name << "_GetInOutQuantArgs();\n\n"; + << "GraphQuantArgs " + << "GetInOutQuantArgs();\n\n"; } -void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx) { +void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::unique_ptr &ctx) { std::vector graph_inputs = ctx->graph_inputs(); Tensor *in_tensor = graph_inputs.at(kInputIndex); MS_CHECK_PTR_IF_NULL(in_tensor); @@ -129,36 +130,41 @@ void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::string &module_n MS_LOG(ERROR) << "code model quant args failed"; return; } - ofs << "GraphQuantArgs " << module_name << "_GetInOutQuantArgs() {\n" + ofs << "GraphQuantArgs " + << "GetInOutQuantArgs() {\n" << "\t\tGraphQuantArgs quan_args = { " << in_quant_args.at(0).scale << ", " << out_quant_args.at(0).scale << ", " << in_quant_args.at(0).zeroPoint << ", " << out_quant_args.at(0).zeroPoint << "};\n" << "\t\treturn quan_args;\n" << "}\n"; } -void CodeManageResourceState(std::ofstream &ofs, const std::string &module_name) { +void CodeManageResourceState(std::ofstream &ofs) { ofs << "/**\n" << " * get the memory space size of the inference.\n" << " **/\n" - << "int " << module_name << "_GetBufferSize();\n"; + << "int " + << "GetBufferSize();\n"; ofs << "/**\n" << " * set the memory space for the inference\n" << " **/\n" - << "int " << module_name << "_SetBuffer(void *buffer);\n\n"; + << "int " + << "SetBuffer(void *buffer);\n\n"; ofs << "/**\n" << " * free the memory of packed weights, and set the membuf buffer and input address to NULL\n" << " **/\n" - << "void " << module_name << "_FreeResource();\n"; + << "void " + << "FreeResource();\n"; } -void CodeInitResourceImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx) { - ofs << "int " << module_name << "_GetBufferSize() {\n" +void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr &ctx) { + ofs << "int " + << "GetBufferSize() {\n" << " return " << ctx->total_buffer_size() << ";\n" << "}\n"; - ofs << "int " << module_name << "_SetBuffer( void *buffer) {\n"; + ofs << "int " + << "SetBuffer( void *buffer) {\n"; ofs << " if (buffer == NULL) {\n" " return RET_ERROR;\n" " }\n"; @@ -167,9 +173,9 @@ void CodeInitResourceImplement(std::ofstream &ofs, const std::string &module_nam "}\n"; } -void CodeFreeResourceImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx) { - ofs << "void " << module_name << "_FreeResource() {\n"; +void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr &ctx) { + ofs << "void " + << "FreeResource() {\n"; ofs << " " << ctx->buffer_name() << "= NULL;\n"; std::vector inputs = ctx->graph_inputs(); size_t size = inputs.size(); @@ -194,11 +200,12 @@ void CodeFreeResourceImplement(std::ofstream &ofs, const std::string &module_nam ofs << "}\n"; } -void CodeInferenceState(std::ofstream &ofs, const std::string &module_name) { +void CodeInferenceState(std::ofstream &ofs) { ofs << "/**\n" << " * net inference function\n" << " **/\n" - << "void " << module_name << "_Inference();\n\n"; + << "void " + << "Inference();\n\n"; } } // namespace mindspore::lite::micro diff --git a/mindspore/lite/micro/coder/generator/component/common_component.h b/mindspore/lite/micro/coder/generator/component/common_component.h index 3abe3499cbf..5162ab331ac 100644 --- a/mindspore/lite/micro/coder/generator/component/common_component.h +++ b/mindspore/lite/micro/coder/generator/component/common_component.h @@ -31,21 +31,18 @@ void CodeSessionCompileGraph(std::ofstream &ofs, const std::unique_ptr &ctx); -void CodeInputState(std::ofstream &ofs, const std::string &module_name); -void CodeInputImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr &ctx); +void CodeInputState(std::ofstream &ofs); +void CodeInputImplement(std::ofstream &ofs, const std::unique_ptr &ctx); -void CodeGraphQuantArgsState(std::ofstream &ofs, const std::string &module_name); -void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx); +void CodeGraphQuantArgsState(std::ofstream &ofs); +void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::unique_ptr &ctx); -void CodeManageResourceState(std::ofstream &ofs, const std::string &module_name); -void CodeInitResourceImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx); +void CodeManageResourceState(std::ofstream &ofs); +void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr &ctx); -void CodeFreeResourceImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx); +void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr &ctx); -void CodeInferenceState(std::ofstream &ofs, const std::string &module_name); +void CodeInferenceState(std::ofstream &ofs); } // namespace mindspore::lite::micro #endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_COMMON_COMPONENT_H_ diff --git a/mindspore/lite/micro/coder/generator/component/const_blocks/msession.cc b/mindspore/lite/micro/coder/generator/component/const_blocks/msession.cc index abfd21b7883..eb9acef8a30 100644 --- a/mindspore/lite/micro/coder/generator/component/const_blocks/msession.cc +++ b/mindspore/lite/micro/coder/generator/component/const_blocks/msession.cc @@ -104,9 +104,9 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af for (size_t i = 0; i < inputs_.size(); ++i) { inputs_data[i] = inputs_[i]->MutableData(); } - net_SetInputs(inputs_data, inputs_.size()); + SetInputs(inputs_data, inputs_.size()); - net_Inference(); + Inference(); void *outputs_data[outputs_.size()]; for (size_t i = 0; i < outputs_.size(); ++i) { @@ -118,7 +118,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af } LiteSession::~LiteSession() { - net_FreeResource(); + FreeResource(); if (runtime_buffer_ != nullptr) { free(runtime_buffer_); runtime_buffer_ = nullptr; @@ -141,12 +141,12 @@ LiteSession::~LiteSession() { } int LiteSession::InitRuntimeBuffer() { - int buffer_size = net_GetBufferSize(); + int buffer_size = GetBufferSize(); runtime_buffer_ = malloc(buffer_size); if (runtime_buffer_ == nullptr) { return RET_ERROR; } - int ret = net_SetBuffer(runtime_buffer_); + int ret = SetBuffer(runtime_buffer_); if (ret != RET_OK) { return RET_ERROR; } @@ -215,7 +215,7 @@ session::LiteSession *session::LiteSession::CreateSession(const char *net_buf, s if (ret != lite::RET_OK) { return nullptr; } - net_Init(const_cast(net_buf), size); + Init(const_cast(net_buf), size); return session; } } // namespace mindspore diff --git a/mindspore/lite/micro/coder/generator/component/parallel_component.cc b/mindspore/lite/micro/coder/generator/component/parallel_component.cc index f01322f61a4..8a73443f0f9 100644 --- a/mindspore/lite/micro/coder/generator/component/parallel_component.cc +++ b/mindspore/lite/micro/coder/generator/component/parallel_component.cc @@ -19,7 +19,7 @@ namespace mindspore::lite::micro { -void CodeCreateThreadPool(std::ofstream &ofs, const std::string &module_name) { +void CodeCreateThreadPool(std::ofstream &ofs) { ofs << " int thread_num = 4;\n" " BindMode bind_mode = NO_BIND_MODE;\n" " if (argc >= 6) {\n" @@ -31,7 +31,8 @@ void CodeCreateThreadPool(std::ofstream &ofs, const std::string &module_name) { " MICRO_ERROR(\"create thread pool failed\");\n" " return RET_ERROR;\n" " }\n" - << " ret = " << module_name << "_SetThreadPool(thread_pool);\n" + << " ret = " + << "SetThreadPool(thread_pool);\n" << " if (ret != RET_OK) {\n" " MICRO_ERROR(\"set global thread pool failed\");\n" " return RET_ERROR;\n" @@ -41,16 +42,18 @@ void CodeCreateThreadPool(std::ofstream &ofs, const std::string &module_name) { void CodeDestroyThreadPool(std::ofstream &ofs) { ofs << " DestroyThreadPool(thread_pool);\n"; } -void CodeSetGlobalThreadPoolState(std::ofstream &ofs, const std::string &module_name) { +void CodeSetGlobalThreadPoolState(std::ofstream &ofs) { ofs << "/*\n" " * set global thread pool, which is created by user\n" " */\n" - << "int " << module_name << "_SetThreadPool(struct ThreadPool *thread_pool);\n\n"; + << "int " + << "SetThreadPool(struct ThreadPool *thread_pool);\n\n"; } -void CodeSetGlobalThreadPoolImplement(std::ofstream &ofs, const std::string &module_name) { +void CodeSetGlobalThreadPoolImplement(std::ofstream &ofs) { ofs << "struct ThreadPool *g_thread_pool = NULL;\n" - << "int " << module_name << "_SetThreadPool(struct ThreadPool *thread_pool) {\n" + << "int " + << "SetThreadPool(struct ThreadPool *thread_pool) {\n" << " if (thread_pool == NULL) {\n" " return RET_ERROR;\n" " }\n" diff --git a/mindspore/lite/micro/coder/generator/component/parallel_component.h b/mindspore/lite/micro/coder/generator/component/parallel_component.h index f92cad26ec8..8a2aaf1d073 100644 --- a/mindspore/lite/micro/coder/generator/component/parallel_component.h +++ b/mindspore/lite/micro/coder/generator/component/parallel_component.h @@ -22,13 +22,13 @@ namespace mindspore::lite::micro { -void CodeCreateThreadPool(std::ofstream &ofs, const std::string &module_name); +void CodeCreateThreadPool(std::ofstream &ofs); void CodeDestroyThreadPool(std::ofstream &ofs); -void CodeSetGlobalThreadPoolState(std::ofstream &ofs, const std::string &module_name); +void CodeSetGlobalThreadPoolState(std::ofstream &ofs); -void CodeSetGlobalThreadPoolImplement(std::ofstream &ofs, const std::string &module_name); +void CodeSetGlobalThreadPoolImplement(std::ofstream &ofs); } // namespace mindspore::lite::micro diff --git a/mindspore/lite/micro/coder/generator/component/train_component.cc b/mindspore/lite/micro/coder/generator/component/train_component.cc index 866b284418d..5e663de25c1 100644 --- a/mindspore/lite/micro/coder/generator/component/train_component.cc +++ b/mindspore/lite/micro/coder/generator/component/train_component.cc @@ -39,24 +39,23 @@ void CodeTrainParams(std::ofstream &ofs) { "};\n\n"; } -void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name) { +void CodeFeaturesState(std::ofstream &ofs) { ofs << "/**\n" " *\n" " * @param size, return the number of features\n" " * @return, the address of features\n" " */\n" - << "FeatureParam *" << module_name << "_GetFeatures(int *size);\n\n"; + << "FeatureParam *GetFeatures(int *size);\n\n"; ofs << "/**\n" " *\n" " * @param features, the address of features\n" " * @param size, the number of features\n" " * @return, status\n" " */\n" - << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size);\n\n"; + << "int UpdateFeatures(FeatureParam *features, int size);\n\n"; } -void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx) { +void CodeFeaturesImplement(std::ofstream &ofs, const std::unique_ptr &ctx) { size_t features_num = 0; ofs << "static FeatureParam feature_params[] = {\n"; for (const auto &item : ctx->saved_weights()) { @@ -72,12 +71,13 @@ void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name, } ofs << "};\n"; - ofs << "FeatureParam *" << module_name << "_GetFeatures(int *size) {\n" + ofs << "FeatureParam *GetFeatures(int *size) {\n" << " *size = " << features_num << ";\n" << " return feature_params;\n" "}\n\n"; - ofs << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size) {\n" + ofs << "int " + << "UpdateFeatures(FeatureParam *features, int size) {\n" << " for (int i = 0; i < size; ++i) {\n" " FeatureParam *src = features + i;\n" " FeatureParam dst;\n" @@ -106,22 +106,22 @@ void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name, "}\n\n"; } -void CodeTrainState(std::ofstream &ofs, const std::string &module_name) { - ofs << "/**\n" - " * Train Function\n" - " * @param epoch, the train epoch\n" - " * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n" - " * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n" - " * parameters to improve the accuracy\n" - " * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n" - " * @return status\n" - " */\n" - << "int " << module_name - << "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, " - "const struct EarlyStop *early_stop);\n\n"; +void CodeTrainState(std::ofstream &ofs) { + ofs + << "/**\n" + " * Train Function\n" + " * @param epoch, the train epoch\n" + " * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n" + " * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n" + " * parameters to improve the accuracy\n" + " * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n" + " * @return status\n" + " */\n" + << "int Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, " + "const struct EarlyStop *early_stop);\n\n"; } -void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr &ctx) { +void CodeTrainImplement(std::ofstream &ofs, const std::unique_ptr &ctx) { std::vector inputs = ctx->graph_inputs(); size_t inputs_num = inputs.size(); auto inputs_tostring = [&]() { @@ -151,8 +151,7 @@ void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, cons } return result; }; - ofs << "int " << module_name - << "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter " + ofs << "int Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter " "*parameter, const struct EarlyStop *early_stop) {\n" " if (iterations <= 0 || epoch <= 0) {\n" " MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n" @@ -169,9 +168,12 @@ void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, cons << " float loss = 0;\n" << " for (int j = 0; j < iterations; ++j) {\n" << " " << offset_inputs() << "\n" - << " " << module_name << "_SetInputs(input_ptr, " << inputs_num << ");\n" - << " " << module_name << "_Inference();\n" - << " loss = " << module_name << "_ComputeLossAndGradient();\n" + << " " + << "_SetInputs(input_ptr, " << inputs_num << ");\n" + << " " + << "_Inference();\n" + << " loss = " + << "ComputeLossAndGradient();\n" << " }\n" " }\n" " return RET_OK;\n" diff --git a/mindspore/lite/micro/coder/generator/component/train_component.h b/mindspore/lite/micro/coder/generator/component/train_component.h index c3c00fe4d4e..0f5aea1c615 100644 --- a/mindspore/lite/micro/coder/generator/component/train_component.h +++ b/mindspore/lite/micro/coder/generator/component/train_component.h @@ -28,12 +28,11 @@ namespace mindspore::lite::micro { void CodeTrainParams(std::ofstream &ofs); -void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name); -void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name, - const std::unique_ptr &ctx); +void CodeFeaturesState(std::ofstream &ofs); +void CodeFeaturesImplement(std::ofstream &ofs, const std::unique_ptr &ctx); -void CodeTrainState(std::ofstream &ofs, const std::string &module_name); -void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr &ctx); +void CodeTrainState(std::ofstream &ofs); +void CodeTrainImplement(std::ofstream &ofs, const std::unique_ptr &ctx); } // namespace mindspore::lite::micro #endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_ diff --git a/mindspore/lite/micro/coder/generator/component/weight_component.cc b/mindspore/lite/micro/coder/generator/component/weight_component.cc index ddad03cfbb4..e606d807777 100644 --- a/mindspore/lite/micro/coder/generator/component/weight_component.cc +++ b/mindspore/lite/micro/coder/generator/component/weight_component.cc @@ -87,16 +87,16 @@ void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std:: cofs << "\n"; } -void CodeInitWeightState(std::ofstream &ofs, const std::string &module_name) { +void CodeInitWeightState(std::ofstream &ofs) { ofs << "/**\n" << " * @param weight_buffer, the address of the weight binary file\n" << " * @param weight_size, the size of the model file in bytes\n" << " **/\n" - << "int " << module_name << "_Init(void *weight_buffer, int weight_size);\n\n"; + << "int Init(void *weight_buffer, int weight_size);\n\n"; } -void CodeWeightInitFunc(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr &ctx) { - ofs << "int " << module_name << "_Init(void *weight_buffer, int weight_size) {\n" +void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr &ctx) { + ofs << "int Init(void *weight_buffer, int weight_size) {\n" << " if (weight_buffer == NULL) {\n" << " return RET_ERROR;\n" << " }\n"; diff --git a/mindspore/lite/micro/coder/generator/component/weight_component.h b/mindspore/lite/micro/coder/generator/component/weight_component.h index 125ca147155..2aa647aba28 100644 --- a/mindspore/lite/micro/coder/generator/component/weight_component.h +++ b/mindspore/lite/micro/coder/generator/component/weight_component.h @@ -35,8 +35,8 @@ void CodeModelParamsData(std::ofstream &ofs, const std::map &saved_weights, const std::string &net_file); void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr &ctx); -void CodeInitWeightState(std::ofstream &ofs, const std::string &module_name); -void CodeWeightInitFunc(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr &ctx); +void CodeInitWeightState(std::ofstream &ofs); +void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr &ctx); } // namespace mindspore::lite::micro diff --git a/mindspore/lite/micro/coder/generator/generator.cc b/mindspore/lite/micro/coder/generator/generator.cc index b6af447e6c9..e990a59b1a1 100644 --- a/mindspore/lite/micro/coder/generator/generator.cc +++ b/mindspore/lite/micro/coder/generator/generator.cc @@ -46,10 +46,9 @@ int WriteContentToFile(const std::string &file, const std::string &content) { Generator::Generator(std::unique_ptr ctx) { ctx_ = std::move(ctx); this->config_ = Configurator::GetInstance(); - std::string module_name = config_->module_name(); - this->net_inc_hfile_ = module_name + ".h"; - this->net_src_cfile_ = module_name + ".c"; - this->net_weight_hfile_ = module_name + "_weight.h"; + this->net_inc_hfile_ = "net.h"; + this->net_src_cfile_ = "net.c"; + this->net_weight_hfile_ = "weight.h"; this->net_src_file_path_ = config_->code_path() + kSourcePath; this->net_main_file_path_ = config_->code_path() + kBenchmarkPath; origin_umask_ = umask(user_umask_); @@ -60,7 +59,7 @@ Generator::~Generator() { (void)umask(origin_umask_); } void Generator::CodeNetRunFunc(std::ofstream &ofs) { // generate net inference code - ofs << "void " << config_->module_name() << "_Inference() {\n"; + ofs << "void Inference() {\n"; if (config_->support_parallel()) { ofs << " const int g_thread_num = GetCurrentThreadNum(g_thread_pool);\n"; } else { @@ -143,7 +142,7 @@ int Generator::CodeWeightFile() { CodeWeightFileHeader(hofs, ctx_); // weight source file - std::string cfile = net_src_file_path_ + config_->module_name() + "_weight.c"; + std::string cfile = net_src_file_path_ + "weight.c"; std::ofstream cofs(cfile); MS_CHECK_TRUE(!cofs.bad(), "filed to open file"); MS_LOG(INFO) << "write " << cfile; @@ -152,10 +151,10 @@ int Generator::CodeWeightFile() { cofs << "unsigned char * " << ctx_->buffer_name() << " = 0 ; \n"; if (config_->target() != kARM32M) { - std::string net_file = net_src_file_path_ + config_->module_name() + ".net"; + std::string net_file = net_src_file_path_ + "net.bin"; SaveDataToNet(ctx_->saved_weights(), net_file); CodeModelParamsForNet(hofs, cofs, ctx_); - CodeWeightInitFunc(cofs, config_->module_name(), ctx_); + CodeWeightInitFunc(cofs, ctx_); } else { CodeModelParamsState(hofs, ctx_->saved_weights()); CodeModelParamsData(cofs, ctx_->saved_weights()); diff --git a/mindspore/lite/micro/coder/generator/inference/inference_generator.cc b/mindspore/lite/micro/coder/generator/inference/inference_generator.cc index 8a72b1b4d26..80b96656ceb 100644 --- a/mindspore/lite/micro/coder/generator/inference/inference_generator.cc +++ b/mindspore/lite/micro/coder/generator/inference/inference_generator.cc @@ -35,19 +35,19 @@ int InferenceGenerator::CodeNetHFile() { ofs << "#include \"thread_pool.h\"\n"; } ofs << kExternCpp; - CodeInputState(ofs, config_->module_name()); + CodeInputState(ofs); CodeCopyOutputsState(ofs); if (is_get_quant_args_) { - CodeGraphQuantArgsState(ofs, config_->module_name()); + CodeGraphQuantArgsState(ofs); } if (config_->support_parallel()) { - CodeSetGlobalThreadPoolState(ofs, config_->module_name()); + CodeSetGlobalThreadPoolState(ofs); } if (config_->target() != kARM32M) { - CodeInitWeightState(ofs, config_->module_name()); + CodeInitWeightState(ofs); } - CodeManageResourceState(ofs, config_->module_name()); - CodeInferenceState(ofs, config_->module_name()); + CodeManageResourceState(ofs); + CodeInferenceState(ofs); ofs << kEndExternCpp; return RET_OK; } @@ -64,14 +64,14 @@ int InferenceGenerator::CodeNetCFile() { ofs << "#include \"" << kDebugUtils << "\"\n"; } if (config_->support_parallel()) { - CodeSetGlobalThreadPoolImplement(ofs, config_->module_name()); + CodeSetGlobalThreadPoolImplement(ofs); } - CodeInputImplement(ofs, config_->module_name(), ctx_); + CodeInputImplement(ofs, ctx_); CodeCopyOutputsImplement(ofs, ctx_); - CodeInitResourceImplement(ofs, config_->module_name(), ctx_); - CodeFreeResourceImplement(ofs, config_->module_name(), ctx_); + CodeInitResourceImplement(ofs, ctx_); + CodeFreeResourceImplement(ofs, ctx_); if (is_get_quant_args_) { - CodeGraphQuantArgsImplement(ofs, config_->module_name(), ctx_); + CodeGraphQuantArgsImplement(ofs, ctx_); } CodeNetRunFunc(ofs); ofs.close(); diff --git a/mindspore/lite/micro/coder/generator/train/train_generator.cc b/mindspore/lite/micro/coder/generator/train/train_generator.cc index f9991892577..e11ea01667d 100644 --- a/mindspore/lite/micro/coder/generator/train/train_generator.cc +++ b/mindspore/lite/micro/coder/generator/train/train_generator.cc @@ -24,7 +24,7 @@ namespace mindspore::lite::micro { void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const { - ofs << "float " << config_->module_name() << "_ComputeLossAndGradient() {\n"; + ofs << "float ComputeLossAndGradient() {\n"; ofs << " float loss = 0;\n"; for (const auto &block : ctx_->train_blocks()) { ofs << "\t{\n" << block << "\t}\n"; @@ -44,14 +44,14 @@ int TrainGenerator::CodeNetHFile() { } ofs << "#include \"microtensor.h\"\n\n"; CodeTrainParams(ofs); - CodeInputState(ofs, config_->module_name()); + CodeInputState(ofs); if (config_->target() != kARM32M) { - CodeInitWeightState(ofs, config_->module_name()); + CodeInitWeightState(ofs); } - CodeManageResourceState(ofs, config_->module_name()); - CodeInferenceState(ofs, config_->module_name()); - CodeFeaturesState(ofs, config_->module_name()); - CodeTrainState(ofs, config_->module_name()); + CodeManageResourceState(ofs); + CodeInferenceState(ofs); + CodeFeaturesState(ofs); + CodeTrainState(ofs); return RET_OK; } @@ -60,13 +60,13 @@ int TrainGenerator::CodeNetCFile() { std::ofstream ofs(net_impl_file); MS_CHECK_TRUE(!ofs.bad(), "filed to open file"); MS_LOG(INFO) << "write " << net_impl_file; - CodeInputImplement(ofs, config_->module_name(), ctx_); - CodeInitResourceImplement(ofs, config_->module_name(), ctx_); - CodeFreeResourceImplement(ofs, config_->module_name(), ctx_); - CodeFeaturesImplement(ofs, config_->module_name(), ctx_); + CodeInputImplement(ofs, ctx_); + CodeInitResourceImplement(ofs, ctx_); + CodeFreeResourceImplement(ofs, ctx_); + CodeFeaturesImplement(ofs, ctx_); CodeNetRunFunc(ofs); CodeGradientFunc(ofs); - CodeTrainImplement(ofs, config_->module_name(), ctx_); + CodeTrainImplement(ofs, ctx_); ofs.close(); return RET_OK; } diff --git a/mindspore/lite/micro/coder/utils/dir_utils.cc b/mindspore/lite/micro/coder/utils/dir_utils.cc index 6bc8df0cab5..2bc527d9829 100644 --- a/mindspore/lite/micro/coder/utils/dir_utils.cc +++ b/mindspore/lite/micro/coder/utils/dir_utils.cc @@ -32,7 +32,7 @@ constexpr _mode_t kMicroDirMode = 0777; constexpr __mode_t kMicroDirMode = 0777; #endif -static std::array kWorkDirs = {"src", "benchmark"}; +static std::array kWorkDirs = {"src", "benchmark"}; bool DirExists(const std::string &dir_path) { struct stat file_info; @@ -76,18 +76,18 @@ static int MkMicroDir(const std::string ¤tDir) { return RET_OK; } -int InitProjDirs(const std::string &pro_root_dir, const std::string &module_name) { +int InitProjDirs(const std::string &project_root_dir, const std::string &proj_name) { #if defined(_WIN32) || defined(_WIN64) std::ofstream pro_file; - std::string read_me_file = pro_root_dir + "\\readMe.txt"; + std::string read_me_file = project_root_dir + "\\readMe.txt"; pro_file.open(read_me_file.c_str()); pro_file << "This is a directory for generating coding files. Do not edit !!!\n"; #else std::ifstream pro_file; - pro_file.open(pro_root_dir.c_str()); + pro_file.open(project_root_dir.c_str()); #endif if (!pro_file.is_open()) { - MS_LOG(ERROR) << pro_root_dir << ": model's root dir not exists or have no access to open, please check it!!!"; + MS_LOG(ERROR) << project_root_dir << ": model's root dir not exists or have no access to open, please check it!!!"; pro_file.close(); return RET_ERROR; } @@ -95,11 +95,10 @@ int InitProjDirs(const std::string &pro_root_dir, const std::string &module_name // 1. coderDir 2.WorkRootDir 3. WorkChildDir std::string current_dir; std::string slashCh = std::string(kSlash); - if (pro_root_dir.back() == slashCh.back()) { - current_dir = pro_root_dir + module_name; - } else { - current_dir = pro_root_dir + slashCh + module_name; + if (project_root_dir.back() != slashCh.back()) { + current_dir = project_root_dir + slashCh; } + current_dir += proj_name; std::string work_dir = current_dir; STATUS ret = MkMicroDir(current_dir); if (ret == RET_ERROR) { diff --git a/mindspore/lite/micro/coder/utils/dir_utils.h b/mindspore/lite/micro/coder/utils/dir_utils.h index 6a378d5a4e0..31a311bd5b9 100644 --- a/mindspore/lite/micro/coder/utils/dir_utils.h +++ b/mindspore/lite/micro/coder/utils/dir_utils.h @@ -24,7 +24,7 @@ static const char kSlash[] = "\\"; static const char kSlash[] = "/"; #endif -int InitProjDirs(const std::string &project_root_dir, const std::string &module_name); +int InitProjDirs(const std::string &project_root_dir, const std::string &proj_name); bool DirExists(const std::string &dir_path);