From ce1dd5eec06b612ee56f1e3fa657f497a4f6f899 Mon Sep 17 00:00:00 2001
From: yangjie159
Date: Thu, 18 Feb 2021 11:48:13 +0800
Subject: [PATCH] feature, micro support train

---
 mindspore/lite/micro/cmake/file_list.cmake        |   3 +
 mindspore/lite/micro/coder/coder.cc               |   5 +-
 mindspore/lite/micro/coder/coder_config.h         |   2 +-
 .../generator/component/train_component.cc        | 180 ++++++++++++++++++
 .../generator/component/train_component.h         |  39 ++++
 mindspore/lite/micro/coder/generator/generator.cc |  10 +-
 .../inference/inference_generator.cc              |   4 +-
 .../coder/generator/train/train_generator.cc      | 108 +++++++++++
 .../coder/generator/train/train_generator.h       |  39 ++++
 mindspore/lite/micro/coder/opcoders/op_coder.h    |   2 +-
 .../micro/coder/opcoders/op_coder_builder.cc      |   2 +-
 mindspore/lite/micro/coder/session.cc             |  16 +-
 mindspore/lite/micro/coder/train.cc               |  95 +++++++++
 mindspore/lite/micro/coder/train.h                |  33 ++++
 .../lite/micro/coder/utils/coder_utils.cc         |  28 ---
 .../lite/micro/coder/utils/coder_utils.h          |   2 -
 16 files changed, 521 insertions(+), 47 deletions(-)
 create mode 100644 mindspore/lite/micro/coder/generator/component/train_component.cc
 create mode 100644 mindspore/lite/micro/coder/generator/component/train_component.h
 create mode 100644 mindspore/lite/micro/coder/generator/train/train_generator.cc
 create mode 100644 mindspore/lite/micro/coder/generator/train/train_generator.h
 create mode 100644 mindspore/lite/micro/coder/train.cc
 create mode 100644 mindspore/lite/micro/coder/train.h

diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake
index 8cda5edf916..7e04e30bf37 100644
--- a/mindspore/lite/micro/cmake/file_list.cmake
+++ b/mindspore/lite/micro/cmake/file_list.cmake
@@ -4,6 +4,7 @@ set(CODER_SRC
     ${MICRO_DIR}/coder/context.cc
     ${MICRO_DIR}/coder/graph.cc
     ${MICRO_DIR}/coder/session.cc
+    ${MICRO_DIR}/coder/train.cc
 )

 set(CODER_ALLOCATOR_SRC
@@ -14,10 +15,12 @@ set(CODER_ALLOCATOR_SRC
 set(CODER_GENERATOR_SRC
     ${MICRO_DIR}/coder/generator/generator.cc
     ${MICRO_DIR}/coder/generator/inference/inference_generator.cc
+    ${MICRO_DIR}/coder/generator/train/train_generator.cc
     ${MICRO_DIR}/coder/generator/component/benchmark_component.cc
     ${MICRO_DIR}/coder/generator/component/common_component.cc
     ${MICRO_DIR}/coder/generator/component/weight_component.cc
     ${MICRO_DIR}/coder/generator/component/cmake_component.cc
+    ${MICRO_DIR}/coder/generator/component/train_component.cc
 )

 set(CODER_OPCODERS_SRC
diff --git a/mindspore/lite/micro/coder/coder.cc b/mindspore/lite/micro/coder/coder.cc
index bb3be2ebe14..74d99b26f2b 100644
--- a/mindspore/lite/micro/coder/coder.cc
+++ b/mindspore/lite/micro/coder/coder.cc
@@ -39,7 +39,7 @@ class CoderFlags : public virtual FlagParser {
     AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
     AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
     AddFlag(&CoderFlags::target_, "target", "generateed code target, x86| ARM32M| ARM32A| ARM64", "x86");
-    AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Android ", "Normal");
+    AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Inference | Train", "Normal");
     AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump perlayer's time cost and tensor, true | false", false);
   }

@@ -87,7 +87,8 @@ int Coder::Run(const std::string &model_path) {

 int Coder::Init(const CoderFlags &flags) const {
   static const std::map<std::string, Target> kTargetMap = {
     {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
-  static const std::map<std::string, CodeMode> kCodeModeMap = {{"Normal", Code_Normal}, {"Android", Code_Android}};
+  static const std::map<std::string, CodeMode> kCodeModeMap = {
+    {"Normal", Code_Normal}, {"Inference", Code_Inference}, {"Train", Code_Train}};

   Configurator *config = Configurator::GetInstance();

diff --git a/mindspore/lite/micro/coder/coder_config.h b/mindspore/lite/micro/coder/coder_config.h
index 5fe9fb2454f..d1b89b6b36a 100644
--- a/mindspore/lite/micro/coder/coder_config.h
+++ b/mindspore/lite/micro/coder/coder_config.h
@@ -21,7 +21,7 @@ namespace mindspore::lite::micro {

 enum Target { kX86 = 0, kARM32M = 1, kARM32A = 2, kARM64 = 3, kAllTargets = 4, kTargetUnknown = 99 };
-enum CodeMode { Code_Normal = 0, Code_Android = 1, Code_Unknown = 99 };
+enum CodeMode { Code_Normal = 0, Code_Inference = 1, Code_Train = 2, Code_Unknown = 99 };

 class Configurator {
  public:
diff --git a/mindspore/lite/micro/coder/generator/component/train_component.cc b/mindspore/lite/micro/coder/generator/component/train_component.cc
new file mode 100644
index 00000000000..66b5512f237
--- /dev/null
+++ b/mindspore/lite/micro/coder/generator/component/train_component.cc
@@ -0,0 +1,180 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "coder/generator/component/train_component.h"
+#include <string>
+#include "coder/utils/type_cast.h"
+
+namespace mindspore::lite::micro {
+
+void CodeTrainParams(std::ofstream &ofs) {
+  ofs << "struct TrainParameter {\n"
+         "  float beta1_;\n"
+         "  float beta2_;\n"
+         "  float epsilon_;\n"
+         "};\n"
+         "\n"
+         "enum EarlyStopType {\n"
+         "  Diff = 0,\n"
+         "  WeightDiff = 1,\n"
+         "  Abs = 2,\n"
+         "};\n"
+         "\n"
+         "struct EarlyStop {\n"
+         "  enum EarlyStopType type;\n"
+         "  float tolerate;\n"
+         "};\n\n";
+}
+
+void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name) {
+  ofs << "/**\n"
+         " *\n"
+         " * @param size, return the number of features\n"
+         " * @return, the address of features\n"
+         " */\n"
+      << "FeatureParam *" << module_name << "_GetFeatures(int *size);\n\n";
+  ofs << "/**\n"
+         " *\n"
+         " * @param features, the address of features\n"
+         " * @param size, the number of features\n"
+         " * @return, status\n"
+         " */\n"
+      << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size);\n\n";
+}
+
+void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
+                           const std::unique_ptr<CoderContext> &ctx) {
+  size_t features_num = 0;
+  ofs << "static FeatureParam feature_params[] = {\n";
+  for (const auto &item : ctx->saved_weights()) {
+    std::string addr = item.first;
+    Tensor *tensor = item.second;
+    if (tensor->tensor_name().empty()) {
+      MS_LOG(ERROR) << "feature tensor name is empty";
+      continue;
+    }
+    ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", "
+        << EnumMicroTensorDataType(tensor->data_type()) << "}, \n";
+    features_num++;
+  }
+  ofs << "};\n";
+
+  ofs << "FeatureParam *" << module_name << "_GetFeatures(int *size) {\n"
+      << "  *size = " << features_num << ";\n"
+      << "  return feature_params;\n"
+         "}\n\n";
+
+  ofs << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size) {\n"
+      << "  for (int i = 0; i < size; ++i) {\n"
+         "    FeatureParam *src = features + i;\n"
+         "    FeatureParam dst;\n"
+         "    // find the dst feature\n"
+         "    bool is_find = false;\n"
+      << "    for (int j = 0; j < " << features_num << "; ++j) {\n"
+      << "      if (strcmp(src->name, feature_params[j].name) == 0) {\n"
+         "        dst = feature_params[j];\n"
+         "        is_find = true;\n"
+         "        break;\n"
+         "      }\n"
+         "    }\n"
+         "    if (!is_find) {\n"
+         "      MICRO_ERROR(\"invalid feature param: %s\", src->name);\n"
+         "      return RET_ERROR;\n"
+         "    }\n"
+         "    if (src->elenums != dst.elenums) {\n"
+         "      MICRO_ERROR(\"feature %s elenums mismatch, src: %lu, dst: %lu\", src->name, src->elenums, "
+         "dst.elenums);\n"
+         "      return RET_ERROR;\n"
+         "    }\n"
+         "    memcpy(dst.data, src->data, src->elenums * sizeof(float));\n"
+         "  }\n"
+         "  MICRO_INFO(\"update features map success\");\n"
+         "  return RET_OK;\n"
+         "}\n\n";
+}
+
+void CodeTrainState(std::ofstream &ofs, const std::string &module_name) {
+  ofs << "/**\n"
+         " * Train Function\n"
+         " * @param epoch, the train epoch\n"
+         " * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n"
+         " * @param use_train_param, default parameters already exist, such as the momentum, user can update these\n"
+         " * parameters to improve the accuracy\n"
+         " * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n"
+         " * @return status\n"
+         " */\n"
+      << "int " << module_name
+      << "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, "
+         "const struct EarlyStop *early_stop);\n\n";
+}
+
+void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx) {
+  std::vector<Tensor *> inputs = ctx->graph_inputs();
+  size_t inputs_num = inputs.size();
+  auto inputs_tostring = [&]() {
+    std::string result;
+    result += "{";
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      result += ctx->input_name() + std::to_string(i) + ", ";
+    }
+    result += "}";
+    return result;
+  };
+  auto wrap = [](size_t i) { return "[" + std::to_string(i) + "]"; };
+  auto offset_inputs = [&]() {
+    std::string src = "origin_input";
+    std::string dst = "input_ptr";
+    std::string result;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      result += dst + wrap(i) + " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n";
+    }
+    return result;
+  };
+  auto verify_inputs = [&]() {
+    std::string result;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL";
+      result += (i < inputs.size() - 1) ? " || " : "";
result += " || " : result += ""; + } + return result; + }; + ofs << "int " << module_name + << "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter " + "*parameter, const struct EarlyStop *early_stop) {\n" + " if (iterations <= 0 || epoch <= 0) {\n" + " MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n" + " return RET_ERROR;\n" + " }\n" + " MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n" + << " const void *origin_input[] = " << inputs_tostring() << ";\n"; + ofs << " if (" << varify_inputs() << ") {\n" + << " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n" + " return RET_ERROR;\n" + " }\n"; + ofs << " for (int i = 0; i < epoch; ++i) {\n" + << " const void *input_ptr[" << inputs_num << "];\n" + << " float loss = 0;\n" + << " for (int j = 0; j < iterations; ++j) {\n" + << " " << offset_inputs() << "\n" + << " " << module_name << "_SetInputs(input_ptr, " << inputs_num << ");\n" + << " " << module_name << "_Inference();\n" + << " loss = " << module_name << "_ComputeLossAndGradient();\n" + << " }\n" + " }\n" + " return RET_OK;\n" + "};\n\n"; +} +} // namespace mindspore::lite::micro diff --git a/mindspore/lite/micro/coder/generator/component/train_component.h b/mindspore/lite/micro/coder/generator/component/train_component.h new file mode 100644 index 00000000000..c3c00fe4d4e --- /dev/null +++ b/mindspore/lite/micro/coder/generator/component/train_component.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
+#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
+
+#include <fstream>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "src/tensor.h"
+#include "coder/context.h"
+
+namespace mindspore::lite::micro {
+void CodeTrainParams(std::ofstream &ofs);
+
+void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name);
+void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
+                           const std::unique_ptr<CoderContext> &ctx);
+
+void CodeTrainState(std::ofstream &ofs, const std::string &module_name);
+void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx);
+}  // namespace mindspore::lite::micro
+
+#endif  // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
diff --git a/mindspore/lite/micro/coder/generator/generator.cc b/mindspore/lite/micro/coder/generator/generator.cc
index 0be8854b26d..7a8cf54d33d 100644
--- a/mindspore/lite/micro/coder/generator/generator.cc
+++ b/mindspore/lite/micro/coder/generator/generator.cc
@@ -59,15 +59,13 @@ Generator::Generator(std::unique_ptr<CoderContext> ctx) {
 Generator::~Generator() { (void)umask(origin_umask_); }

 void Generator::CodeNetRunFunc(std::ofstream &ofs) {
-  // generate net predict code
+  // generate net inference code
   ofs << "void " << config_->module_name() << "_Inference() {\n";
-  if (config_->code_mode() == CodeMode::Code_Android) {
+  if (config_->code_mode() == CodeMode::Code_Inference) {
     ofs << "int thread_num = GetCurrentThreadNum(THREAD_POOL_DEFAULT);\n";
   }
-  for (const auto &codeBlock : ctx_->code_blocks()) {
-    ofs << "\t{\n";
-    ofs << codeBlock;
-    ofs << "\t}\n";
+  for (const auto &block : ctx_->code_blocks()) {
+    ofs << "\t{\n" << block << "\t}\n";
   }
   ofs << "}\n";
 }
diff --git a/mindspore/lite/micro/coder/generator/inference/inference_generator.cc b/mindspore/lite/micro/coder/generator/inference/inference_generator.cc
index 33acd925cf5..cfa4c578304 100644
--- a/mindspore/lite/micro/coder/generator/inference/inference_generator.cc
+++ b/mindspore/lite/micro/coder/generator/inference/inference_generator.cc
@@ -28,7 +28,7 @@ int InferenceGenerator::CodeNetHFile() {
   MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
   MS_LOG(INFO) << "write " << net_include_file;
   ofs << g_hwLicense;
-  if (config_->code_mode() == CodeMode::Code_Android) {
+  if (config_->code_mode() == CodeMode::Code_Inference) {
     ofs << "#include \"src/runtime/thread_pool.h\"\n";
   }
   ofs << "#include \"microtensor.h\"\n\n";
@@ -78,7 +78,7 @@ int InferenceGenerator::CodeBenchmarkFile() {
   if (config_->is_weight_file()) {
     CodeBenchmarkInitWeight(ofs, config_->module_name());
   }
-  if (config_->code_mode() == CodeMode::Code_Android) {
+  if (config_->code_mode() == CodeMode::Code_Inference) {
     CodeBenchmarkConfigThread(ofs);
   }
   CodeBenchmarkInference(ofs, config_->module_name());
diff --git a/mindspore/lite/micro/coder/generator/train/train_generator.cc b/mindspore/lite/micro/coder/generator/train/train_generator.cc
new file mode 100644
index 00000000000..b8aa287098a
--- /dev/null
+++ b/mindspore/lite/micro/coder/generator/train/train_generator.cc
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "coder/generator/train/train_generator.h"
+#include <string>
+#include <vector>
+#include "coder/generator/component/common_component.h"
+#include "coder/generator/component/benchmark_component.h"
+#include "coder/generator/component/train_component.h"
+#include "coder/generator/component/const_blocks/license.h"
+
+namespace mindspore::lite::micro {
+void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
+  ofs << "float " << config_->module_name() << "_ComputeLossAndGradient() {\n";
+  ofs << "  float loss = 0;\n";
+  for (const auto &block : ctx_->train_blocks()) {
+    ofs << "  {\n" << block << "  }\n";
+  }
+  ofs << "  return loss;\n";
+  ofs << "}\n";
+}
+
+int TrainGenerator::CodeNetHFile() {
+  std::string net_include_file = net_inc_file_path_ + net_inc_hfile_;
+  std::ofstream ofs(net_include_file);
+  MS_CHECK_TRUE(!ofs.bad(), "failed to open file");
+  MS_LOG(INFO) << "write " << net_include_file;
+  ofs << g_hwLicense;
+  if (config_->code_mode() == CodeMode::Code_Inference) {
+    ofs << "#include \"src/runtime/thread_pool.h\"\n";
+  }
+  ofs << "#include \"microtensor.h\"\n\n";
+  CodeTrainParams(ofs);
+  CodeInputAndOutputState(ofs, config_->module_name());
+  if (is_get_quant_args_) {
+    CodeGraphQuantArgsState(ofs, config_->module_name());
+  }
+  if (config_->is_weight_file()) {
+    CodeInitWeightState(ofs, config_->module_name());
+  }
+  CodeManageResourceState(ofs, config_->module_name());
+  CodeInferenceState(ofs, config_->module_name());
+  CodeFeaturesState(ofs, config_->module_name());
+  CodeTrainState(ofs, config_->module_name());
+  return RET_OK;
+}
+
+int TrainGenerator::CodeNetCFile() {
+  std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
+  std::ofstream ofs(net_impl_file);
+  MS_CHECK_TRUE(!ofs.bad(), "failed to open file");
+  MS_LOG(INFO) << "write " << net_impl_file;
+  CodeSourceFileInclude(ofs, net_weight_hfile_, net_inc_hfile_);
+  CodeInputAndOutputImplement(ofs, config_->module_name(), ctx_);
+  CodeInitResourceImplement(ofs, config_->module_name(), ctx_);
+  CodeFreeResourceImplement(ofs, config_->module_name(), ctx_);
+  CodeFeaturesImplement(ofs, config_->module_name(), ctx_);
+  if (is_get_quant_args_) {
+    CodeGraphQuantArgsImplement(ofs, config_->module_name(), ctx_);
+  }
+  CodeNetRunFunc(ofs);
+  CodeGradientFunc(ofs);
+  CodeTrainImplement(ofs, config_->module_name(), ctx_);
+  ofs.close();
+  return RET_OK;
+}
+
+int TrainGenerator::CodeBenchmarkFile() {
+  std::string net_main_impl_file = net_main_file_path_ + net_main_cfile_;
+  std::ofstream ofs(net_main_impl_file);
+  MS_LOG(INFO) << "write " << net_main_impl_file;
+  MS_CHECK_TRUE(!ofs.bad(), "failed to open file");
+  std::vector<Tensor *> inputs = ctx_->graph_inputs();
+  size_t inputs_num = inputs.size();
+
+  CodeBenchmarkHeader(ofs, net_inc_hfile_);
+  CodeBenchmarkUsage(ofs);
+  CodeBenchmarkWarmup(ofs, config_->module_name());
+
+  CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
+  CodeBenchmarkSetBuffer(ofs, config_->module_name());
+  if (config_->is_weight_file()) {
+    CodeBenchmarkInitWeight(ofs, config_->module_name());
+  }
+  if (config_->code_mode() == CodeMode::Code_Inference) {
+    CodeBenchmarkConfigThread(ofs);
+  }
+  CodeBenchmarkInference(ofs, config_->module_name());
+  CodeBenchmarkPrintOutputs(ofs, config_->module_name());
+
+  CodeBenchmarkFreeResourse(ofs, config_->module_name(), inputs_num);
+  ofs.close();
+  return RET_OK;
+}
+}  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/micro/coder/generator/train/train_generator.h b/mindspore/lite/micro/coder/generator/train/train_generator.h
new file mode 100644
index 00000000000..c18fd2b5b17
--- /dev/null
+++ b/mindspore/lite/micro/coder/generator/train/train_generator.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
+#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
+
+#include <memory>
+#include <utility>
+#include "micro/coder/generator/generator.h"
+
+namespace mindspore::lite::micro {
+class TrainGenerator : public Generator {
+ public:
+  explicit TrainGenerator(std::unique_ptr<CoderContext> ctx) : Generator(std::move(ctx)) {}
+  ~TrainGenerator() override = default;
+
+ private:
+  int CodeNetHFile() override;
+  int CodeNetCFile() override;
+
+  int CodeBenchmarkFile() override;
+
+  void CodeGradientFunc(std::ofstream &ofs) const;
+};
+}  // namespace mindspore::lite::micro
+#endif  // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
diff --git a/mindspore/lite/micro/coder/opcoders/op_coder.h b/mindspore/lite/micro/coder/opcoders/op_coder.h
index ced4cb489a3..1560b79ddcf 100644
--- a/mindspore/lite/micro/coder/opcoders/op_coder.h
+++ b/mindspore/lite/micro/coder/opcoders/op_coder.h
@@ -28,8 +28,8 @@
 #include "securec/include/securec.h"
 #include "coder/opcoders/op_coder_register.h"
 #include "coder/log.h"
+
 namespace mindspore::lite::micro {
-class CoderContext;
 constexpr int kPrecision = 19;

 #define CODE_PARALLEL_FUNC(func) code << "ParallelLaunch(THREAD_POOL_DEFAULT, " << func << ", &args, thread_num);\n"
diff --git a/mindspore/lite/micro/coder/opcoders/op_coder_builder.cc b/mindspore/lite/micro/coder/opcoders/op_coder_builder.cc
index 2511da37e83..106d6ddfaff 100644
--- a/mindspore/lite/micro/coder/opcoders/op_coder_builder.cc
+++ b/mindspore/lite/micro/coder/opcoders/op_coder_builder.cc
@@ -61,7 +61,7 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build() {
   }
   op_coder->set_input_tensor_indices(input_indices_);
   op_coder->set_output_tensor_indices(output_indices_);
-  int thread_num = this->mode_ == CodeMode::Code_Android ? kMAX_THREAD_NUM_SUPPORT : 1;
+  int thread_num = this->mode_ == CodeMode::Code_Inference ? kMAX_THREAD_NUM_SUPPORT : 1;
   op_coder->set_thread_num(thread_num);
   parameter->thread_num_ = thread_num;
   op_coder->set_parameter(parameter);
diff --git a/mindspore/lite/micro/coder/session.cc b/mindspore/lite/micro/coder/session.cc
index ab4a8617f0d..777502a83ec 100644
--- a/mindspore/lite/micro/coder/session.cc
+++ b/mindspore/lite/micro/coder/session.cc
@@ -16,13 +16,14 @@
 #include "coder/session.h"
 #include
-#include
 #include
 #include
-#include "coder/allocator/allocator.h"
 #include "coder/context.h"
+#include "coder/train.h"
+#include "coder/allocator/allocator.h"
 #include "coder/generator/generator.h"
 #include "coder/generator/inference/inference_generator.h"
+#include "coder/generator/train/train_generator.h"
 #include "coder/opcoders/op_coder_builder.h"
 #include "coder/utils/coder_utils.h"
 #include "coder/log.h"
@@ -89,6 +90,9 @@ void CoderSession::EndCode() {
     blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_);
     context_->set_code_blocks(blocks);
   }
+  if (config->code_mode() == Code_Train) {
+    Train::TransformGraphForTrain(context_.get(), op_coders_);
+  }
 }

 int CoderSession::Run() {
@@ -123,10 +127,14 @@ int CoderSession::GenerateCode() {
   CodeMode code_mode = config->code_mode();
   switch (code_mode) {
     case Code_Normal:
-    case Code_Android:
-      MS_LOG(INFO) << "generate code for Android";
+    case Code_Inference:
+      MS_LOG(INFO) << "generate code for Inference";
       generator = std::make_shared<InferenceGenerator>(std::move(context_));
       break;
+    case Code_Train:
+      MS_LOG(INFO) << "generate code for Train";
+      generator = std::make_shared<TrainGenerator>(std::move(context_));
+      break;
     default:
       MS_LOG(ERROR) << "unsupported generator code mode, " << code_mode;
       return RET_ERROR;
diff --git a/mindspore/lite/micro/coder/train.cc b/mindspore/lite/micro/coder/train.cc
new file mode 100644
index 00000000000..b2bfb07993d
--- /dev/null
+++ b/mindspore/lite/micro/coder/train.cc
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "coder/train.h"
+#include <memory>
+#include <queue>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace mindspore::lite::micro {
+
+std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
+  std::set<OperatorCoder *> subgraph;
+  std::queue<OperatorCoder *> to_visit;
+  to_visit.push(edge);
+  while (!to_visit.empty()) {
+    size_t size = to_visit.size();
+    for (size_t i = 0; i < size; ++i) {
+      OperatorCoder *curr = to_visit.front();
+      to_visit.pop();
+      if (subgraph.find(curr) != subgraph.end()) {
+        continue;
+      }
+      subgraph.insert(curr);
+      for (const auto &op : curr->input_ops()) {
+        to_visit.push(op);
+      }
+    }
+  }
+  auto item = subgraph.find(edge);
+  if (item == subgraph.end()) {
+    MS_LOG(ERROR) << "failed to find the edge in the subgraph";
+    return subgraph;
+  }
+  // erase edge operator coder from subgraph
+  subgraph.erase(item);
+  return subgraph;
+}
+
+int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
+  const std::set<schema::PrimitiveType> loss_types = {schema::PrimitiveType_SoftmaxCrossEntropy,
+                                                      schema::PrimitiveType_SparseSoftmaxCrossEntropy,
+                                                      schema::PrimitiveType_BinaryCrossEntropy,
+                                                      schema::PrimitiveType_SmoothL1Loss,
+                                                      schema::PrimitiveType_SmoothL1LossGrad,
+                                                      schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
+                                                      schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
+  OperatorCoder *loss_op = nullptr;
+  for (const auto &opcoder : op_coders) {
+    auto primitive_type = static_cast<schema::PrimitiveType>(opcoder->primitive()->Type());
+    auto item = loss_types.find(primitive_type);
+    if (item != loss_types.end()) {
+      loss_op = opcoder.get();
+      break;
+    }
+  }
+  MS_CHECK_PTR(loss_op);
+  size_t op_num = op_coders.size();
+  std::vector<std::string> code_blocks = context->code_blocks();
+  if (op_num != code_blocks.size()) {
+    MS_LOG(ERROR) << "the number of code blocks and op coders is not equal";
+    return RET_ERROR;
+  }
+  std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
+  std::vector<std::string> inference_blocks;
+  std::vector<std::string> train_blocks;
+  for (size_t i = 0; i < op_num; ++i) {
+    auto &opcoder = op_coders.at(i);
+    std::string block = code_blocks.at(i);
+    if (inference_ops.find(opcoder.get()) != inference_ops.end()) {
+      inference_blocks.push_back(block);
+    }
+    train_blocks.push_back(block);
+  }
+  context->set_inference_blocks(inference_blocks);
+  context->set_train_blocks(train_blocks);
+  return RET_OK;
+}
+
+}  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/micro/coder/train.h b/mindspore/lite/micro/coder/train.h
new file mode 100644
index 00000000000..fe335e6dd16
--- /dev/null
+++ b/mindspore/lite/micro/coder/train.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
+#define MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
+
+#include <memory>
+#include <vector>
+#include "coder/context.h"
+#include "coder/opcoders/op_coder.h"
+
+namespace mindspore::lite::micro {
+class Train {
+ public:
+  static int TransformGraphForTrain(CoderContext *context,
+                                    const std::vector<std::unique_ptr<OperatorCoder>> &op_coders);
+};
+
+}  // namespace mindspore::lite::micro
+#endif  // MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
diff --git a/mindspore/lite/micro/coder/utils/coder_utils.cc b/mindspore/lite/micro/coder/utils/coder_utils.cc
index 731d17ff5f8..55a92e49c06 100644
--- a/mindspore/lite/micro/coder/utils/coder_utils.cc
+++ b/mindspore/lite/micro/coder/utils/coder_utils.cc
@@ -142,32 +142,4 @@ std::vector<std::string> SplitString(std::string str, const std::string &pattern
   }
   return results;
 }
-
-std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
-  std::set<OperatorCoder *> subgraph;
-  std::queue<OperatorCoder *> to_visit;
-  to_visit.push(edge);
-  while (!to_visit.empty()) {
-    size_t size = to_visit.size();
-    for (size_t i = 0; i < size; ++i) {
-      OperatorCoder *curr = to_visit.front();
-      to_visit.pop();
-      if (subgraph.find(curr) != subgraph.end()) {
-        continue;
-      }
-      subgraph.insert(curr);
-      for (const auto &op : curr->input_ops()) {
-        to_visit.push(op);
-      }
-    }
-  }
-  auto item = subgraph.find(edge);
-  if (item == subgraph.end()) {
-    MS_LOG(ERROR) << "failed to find the edge in the subgraph";
-    return subgraph;
-  }
-  // erase edge operator coder from subgraph
-  subgraph.erase(item);
-  return subgraph;
-}
 }  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/micro/coder/utils/coder_utils.h b/mindspore/lite/micro/coder/utils/coder_utils.h
index 2a8e0588567..04f94972156 100644
--- a/mindspore/lite/micro/coder/utils/coder_utils.h
+++ b/mindspore/lite/micro/coder/utils/coder_utils.h
@@ -35,8 +35,6 @@ std::vector<std::string> AddDumpDataInfo(const std::vector<std::string> &blocks,

 void PrintTensorData(const lite::Tensor *tensor, std::ofstream &ofs);

-std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge);
-
 }  // namespace mindspore::lite::micro

 #endif  // MINDSPORE_LITE_MICRO_CODER_UTILS_CODER_UTILS_H_
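Reviewer note: a minimal sketch of how the training API emitted by this patch would be driven from user code, for review context only. It assumes a model converted with --moduleName=net and a single graph input; the buffer size, epoch and iteration counts, and the header name "net.h" are illustrative, not taken from this patch.

#include <stdbool.h>
#include <stdio.h>
#include "net.h" /* generated net header; actual name depends on --moduleName */

int main(void) {
  /* The caller binds the whole training set once; net_Train() then walks it
   * batch by batch, offsetting each input by j * input_size per iteration. */
  static float train_data[100 * 28 * 28]; /* illustrative: 100 batches of one input */
  const void *inputs[] = {train_data};
  net_SetInputs(inputs, 1);

  /* Field order matches the generated struct: beta1_, beta2_, epsilon_. */
  struct TrainParameter param = {0.9f, 0.999f, 1e-8f};
  if (net_Train(/*epoch=*/5, /*iterations=*/100, /*use_train_param=*/true, &param,
                /*early_stop=*/NULL) != RET_OK) {
    return -1;
  }

  /* After training, read the updated weights back, e.g. to serialize them or
   * push them into another instance via net_UpdateFeatures(). */
  int size = 0;
  FeatureParam *features = net_GetFeatures(&size);
  printf("trained %d features, first: %s\n", size, features[0].name);
  return 0;
}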