forked from mindspore-Ecosystem/mindspore
!12182 [MSLITE]feature, micro support train
From: @yangjie159 Reviewed-by: @wangchengyuan,@HilbertDavid Signed-off-by: @wangchengyuan
This commit is contained in:
commit
9d2e07ae24
|
@ -4,6 +4,7 @@ set(CODER_SRC
|
|||
${MICRO_DIR}/coder/context.cc
|
||||
${MICRO_DIR}/coder/graph.cc
|
||||
${MICRO_DIR}/coder/session.cc
|
||||
${MICRO_DIR}/coder/train.cc
|
||||
)
|
||||
|
||||
set(CODER_ALLOCATOR_SRC
|
||||
|
@ -14,10 +15,12 @@ set(CODER_ALLOCATOR_SRC
|
|||
set(CODER_GENERATOR_SRC
|
||||
${MICRO_DIR}/coder/generator/generator.cc
|
||||
${MICRO_DIR}/coder/generator/inference/inference_generator.cc
|
||||
${MICRO_DIR}/coder/generator/train/train_generator.cc
|
||||
${MICRO_DIR}/coder/generator/component/benchmark_component.cc
|
||||
${MICRO_DIR}/coder/generator/component/common_component.cc
|
||||
${MICRO_DIR}/coder/generator/component/weight_component.cc
|
||||
${MICRO_DIR}/coder/generator/component/cmake_component.cc
|
||||
${MICRO_DIR}/coder/generator/component/train_component.cc
|
||||
)
|
||||
|
||||
set(CODER_OPCODERS_SRC
|
||||
|
|
|
@ -39,7 +39,7 @@ class CoderFlags : public virtual FlagParser {
|
|||
AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
|
||||
AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
|
||||
AddFlag(&CoderFlags::target_, "target", "generateed code target, x86| ARM32M| ARM32A| ARM64", "x86");
|
||||
AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Android ", "Normal");
|
||||
AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Inference | Train", "Normal");
|
||||
AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump perlayer's time cost and tensor, true | false", false);
|
||||
}
|
||||
|
||||
|
@ -87,7 +87,8 @@ int Coder::Run(const std::string &model_path) {
|
|||
int Coder::Init(const CoderFlags &flags) const {
|
||||
static const std::map<std::string, Target> kTargetMap = {
|
||||
{"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
|
||||
static const std::map<std::string, CodeMode> kCodeModeMap = {{"Normal", Code_Normal}, {"Android", Code_Android}};
|
||||
static const std::map<std::string, CodeMode> kCodeModeMap = {
|
||||
{"Normal", Code_Normal}, {"Inference", Code_Inference}, {"Train", Code_Train}};
|
||||
|
||||
Configurator *config = Configurator::GetInstance();
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
|
||||
namespace mindspore::lite::micro {
|
||||
enum Target { kX86 = 0, kARM32M = 1, kARM32A = 2, kARM64 = 3, kAllTargets = 4, kTargetUnknown = 99 };
|
||||
enum CodeMode { Code_Normal = 0, Code_Android = 1, Code_Unknown = 99 };
|
||||
enum CodeMode { Code_Normal = 0, Code_Inference = 1, Code_Train = 2, Code_Unknown = 99 };
|
||||
|
||||
class Configurator {
|
||||
public:
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/generator/component/train_component.h"
|
||||
#include <string>
|
||||
#include "coder/utils/type_cast.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
|
||||
void CodeTrainParams(std::ofstream &ofs) {
|
||||
ofs << "struct TrainParameter {\n"
|
||||
" float beta1_;\n"
|
||||
" float beta2_;\n"
|
||||
" float epsilon_;\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"enum EarlyStopType {\n"
|
||||
" Diff = 0,\n"
|
||||
" WeigthDiff = 1,\n"
|
||||
" Abs = 2,\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"struct EarlyStop {\n"
|
||||
" enum EarlyStopType type;\n"
|
||||
" float tolerate;\n"
|
||||
"};\n\n";
|
||||
}
|
||||
|
||||
void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name) {
|
||||
ofs << "/**\n"
|
||||
" *\n"
|
||||
" * @param size, return the number of features\n"
|
||||
" * @return, the address of features\n"
|
||||
" */\n"
|
||||
<< "FeatureParam *" << module_name << "_GetFeatures(int *size);\n\n";
|
||||
ofs << "/**\n"
|
||||
" *\n"
|
||||
" * @param features, the address of features\n"
|
||||
" * @param size, the number of features\n"
|
||||
" * @return, status\n"
|
||||
" */\n"
|
||||
<< "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size);\n\n";
|
||||
}
|
||||
|
||||
void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
|
||||
const std::unique_ptr<CoderContext> &ctx) {
|
||||
size_t features_num = 0;
|
||||
ofs << "static FeatureParam feature_params[] = {\n";
|
||||
for (const auto &item : ctx->saved_weights()) {
|
||||
std::string addr = item.first;
|
||||
Tensor *tensor = item.second;
|
||||
if (tensor->tensor_name().empty()) {
|
||||
MS_LOG(ERROR) << "exist empty feature";
|
||||
continue;
|
||||
}
|
||||
ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", "
|
||||
<< EnumMicroTensorDataType(tensor->data_type()) << "}, \n";
|
||||
features_num++;
|
||||
}
|
||||
ofs << "};\n";
|
||||
|
||||
ofs << "FeatureParam *" << module_name << "_GetFeatures(int *size) {\n"
|
||||
<< " *size = " << features_num << ";\n"
|
||||
<< " return feature_params;\n"
|
||||
"}\n\n";
|
||||
|
||||
ofs << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size) {\n"
|
||||
<< " for (int i = 0; i < size; ++i) {\n"
|
||||
" FeatureParam *src = features + i;\n"
|
||||
" FeatureParam dst;\n"
|
||||
" // find the dst feature\n"
|
||||
" bool is_find = false;\n"
|
||||
<< " for (int j = 0; j < " << features_num << "; ++j) {\n"
|
||||
<< " if (strcmp(src->name, feature_params[j].name) == 0) {\n"
|
||||
" dst = feature_params[j];\n"
|
||||
" is_find = true;\n"
|
||||
" break;\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" if (!is_find) {\n"
|
||||
" MICRO_ERROR(\"invalid feature param: %s\", src->name);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n"
|
||||
" if (src->elenums != dst.elenums) {\n"
|
||||
" MICRO_ERROR(\"feature %s elenums is mismatch, src: %lu, dst: %lu\", src->name, src->elenums, "
|
||||
"dst.elenums);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n"
|
||||
" memcpy(dst.data, src->data, src->elenums * sizeof(float));\n"
|
||||
" }\n"
|
||||
" MICRO_INFO(\"update features map success\");\n"
|
||||
" return RET_OK;\n"
|
||||
"}\n\n";
|
||||
}
|
||||
|
||||
void CodeTrainState(std::ofstream &ofs, const std::string &module_name) {
|
||||
ofs << "/**\n"
|
||||
" * Train Function\n"
|
||||
" * @param epoch, the train epoch\n"
|
||||
" * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n"
|
||||
" * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n"
|
||||
" * parameters to improve the accuracy\n"
|
||||
" * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n"
|
||||
" * @return status\n"
|
||||
" */\n"
|
||||
<< "int " << module_name
|
||||
<< "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, "
|
||||
"const struct EarlyStop *early_stop);\n\n";
|
||||
}
|
||||
|
||||
void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx) {
|
||||
std::vector<Tensor *> inputs = ctx->graph_inputs();
|
||||
size_t inputs_num = inputs.size();
|
||||
auto inputs_tostring = [&]() {
|
||||
std::string result;
|
||||
result += "{";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
result += ctx->input_name() + std::to_string(i) + ", ";
|
||||
}
|
||||
result += "}";
|
||||
return result;
|
||||
};
|
||||
auto wrap = [](int i) { return "[" + std::to_string(i) + "]"; };
|
||||
auto offset_inputs = [&]() {
|
||||
std::string src = "origin_inputs";
|
||||
std::string dst = "input_ptr";
|
||||
std::string result;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
result += dst + wrap(i) += " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n";
|
||||
}
|
||||
return result;
|
||||
};
|
||||
auto varify_inputs = [&]() {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL";
|
||||
i < inputs.size() - 1 ? result += " || " : result += "";
|
||||
}
|
||||
return result;
|
||||
};
|
||||
ofs << "int " << module_name
|
||||
<< "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter "
|
||||
"*parameter, const struct EarlyStop *early_stop) {\n"
|
||||
" if (iterations <= 0 || epoch <= 0) {\n"
|
||||
" MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n"
|
||||
" MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n"
|
||||
<< " const void *origin_input[] = " << inputs_tostring() << ";\n";
|
||||
ofs << " if (" << varify_inputs() << ") {\n"
|
||||
<< " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n"
|
||||
" return RET_ERROR;\n"
|
||||
" }\n";
|
||||
ofs << " for (int i = 0; i < epoch; ++i) {\n"
|
||||
<< " const void *input_ptr[" << inputs_num << "];\n"
|
||||
<< " float loss = 0;\n"
|
||||
<< " for (int j = 0; j < iterations; ++j) {\n"
|
||||
<< " " << offset_inputs() << "\n"
|
||||
<< " " << module_name << "_SetInputs(input_ptr, " << inputs_num << ");\n"
|
||||
<< " " << module_name << "_Inference();\n"
|
||||
<< " loss = " << module_name << "_ComputeLossAndGradient();\n"
|
||||
<< " }\n"
|
||||
" }\n"
|
||||
" return RET_OK;\n"
|
||||
"};\n\n";
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <fstream>
|
||||
#include "src/tensor.h"
|
||||
#include "coder/context.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
void CodeTrainParams(std::ofstream &ofs);
|
||||
|
||||
void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name);
|
||||
void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
|
||||
const std::unique_ptr<CoderContext> &ctx);
|
||||
|
||||
void CodeTrainState(std::ofstream &ofs, const std::string &module_name);
|
||||
void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx);
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
|
|
@ -59,15 +59,13 @@ Generator::Generator(std::unique_ptr<CoderContext> ctx) {
|
|||
Generator::~Generator() { (void)umask(origin_umask_); }
|
||||
|
||||
void Generator::CodeNetRunFunc(std::ofstream &ofs) {
|
||||
// generate net predict code
|
||||
// generate net inference code
|
||||
ofs << "void " << config_->module_name() << "_Inference() {\n";
|
||||
if (config_->code_mode() == CodeMode::Code_Android) {
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
ofs << "int thread_num = GetCurrentThreadNum(THREAD_POOL_DEFAULT);\n";
|
||||
}
|
||||
for (const auto &codeBlock : ctx_->code_blocks()) {
|
||||
ofs << "\t{\n";
|
||||
ofs << codeBlock;
|
||||
ofs << "\t}\n";
|
||||
for (const auto &block : ctx_->code_blocks()) {
|
||||
ofs << "\t{\n" << block << "\t}\n";
|
||||
}
|
||||
ofs << "}\n";
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ int InferenceGenerator::CodeNetHFile() {
|
|||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
MS_LOG(INFO) << "write " << net_include_file;
|
||||
ofs << g_hwLicense;
|
||||
if (config_->code_mode() == CodeMode::Code_Android) {
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
ofs << "#include \"src/runtime/thread_pool.h\"\n";
|
||||
}
|
||||
ofs << "#include \"microtensor.h\"\n\n";
|
||||
|
@ -78,7 +78,7 @@ int InferenceGenerator::CodeBenchmarkFile() {
|
|||
if (config_->is_weight_file()) {
|
||||
CodeBenchmarkInitWeight(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->code_mode() == CodeMode::Code_Android) {
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
CodeBenchmarkConfigThread(ofs);
|
||||
}
|
||||
CodeBenchmarkInference(ofs, config_->module_name());
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/generator/train/train_generator.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "coder/generator/component/common_component.h"
|
||||
#include "coder/generator/component/benchmark_component.h"
|
||||
#include "coder/generator/component/train_component.h"
|
||||
#include "coder/generator/component/const_blocks/license.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
|
||||
ofs << "float " << config_->module_name() << "_ComputeLossAndGradient() {\n";
|
||||
ofs << " float loss = 0;\n";
|
||||
for (const auto &block : ctx_->train_blocks()) {
|
||||
ofs << " {\n" << block << " }\n";
|
||||
}
|
||||
ofs << " return loss;\n";
|
||||
ofs << "}\n";
|
||||
}
|
||||
|
||||
int TrainGenerator::CodeNetHFile() {
|
||||
std::string net_include_file = net_inc_file_path_ + net_inc_hfile_;
|
||||
std::ofstream ofs(net_include_file);
|
||||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
MS_LOG(INFO) << "write " << net_include_file;
|
||||
ofs << g_hwLicense;
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
ofs << "#include \"src/runtime/thread_pool.h\"\n";
|
||||
}
|
||||
ofs << "#include \"microtensor.h\"\n\n";
|
||||
CodeTrainParams(ofs);
|
||||
CodeInputAndOutputState(ofs, config_->module_name());
|
||||
if (is_get_quant_args_) {
|
||||
CodeGraphQuantArgsState(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->is_weight_file()) {
|
||||
CodeInitWeightState(ofs, config_->module_name());
|
||||
}
|
||||
CodeManageResourceState(ofs, config_->module_name());
|
||||
CodeInferenceState(ofs, config_->module_name());
|
||||
CodeFeaturesState(ofs, config_->module_name());
|
||||
CodeTrainState(ofs, config_->module_name());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TrainGenerator::CodeNetCFile() {
|
||||
std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
|
||||
std::ofstream ofs(net_impl_file);
|
||||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
MS_LOG(INFO) << "write " << net_impl_file;
|
||||
CodeSourceFileInclude(ofs, net_weight_hfile_, net_inc_hfile_);
|
||||
CodeInputAndOutputImplement(ofs, config_->module_name(), ctx_);
|
||||
CodeInitResourceImplement(ofs, config_->module_name(), ctx_);
|
||||
CodeFreeResourceImplement(ofs, config_->module_name(), ctx_);
|
||||
CodeFeaturesImplement(ofs, config_->module_name(), ctx_);
|
||||
if (is_get_quant_args_) {
|
||||
CodeGraphQuantArgsImplement(ofs, config_->module_name(), ctx_);
|
||||
}
|
||||
CodeNetRunFunc(ofs);
|
||||
CodeGradientFunc(ofs);
|
||||
CodeTrainImplement(ofs, config_->module_name(), ctx_);
|
||||
ofs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TrainGenerator::CodeBenchmarkFile() {
|
||||
std::string net_main_impl_file = net_main_file_path_ + net_main_cfile_;
|
||||
std::ofstream ofs(net_main_impl_file);
|
||||
MS_LOG(INFO) << "write " << net_main_impl_file;
|
||||
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
|
||||
std::vector<Tensor *> inputs = ctx_->graph_inputs();
|
||||
size_t inputs_num = inputs.size();
|
||||
|
||||
CodeBenchmarkHeader(ofs, net_inc_hfile_);
|
||||
CodeBenchmarkUsage(ofs);
|
||||
CodeBenchmarkWarmup(ofs, config_->module_name());
|
||||
|
||||
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
|
||||
CodeBenchmarkSetBuffer(ofs, config_->module_name());
|
||||
if (config_->is_weight_file()) {
|
||||
CodeBenchmarkInitWeight(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->code_mode() == CodeMode::Code_Inference) {
|
||||
CodeBenchmarkConfigThread(ofs);
|
||||
}
|
||||
CodeBenchmarkInference(ofs, config_->module_name());
|
||||
CodeBenchmarkPrintOutputs(ofs, config_->module_name());
|
||||
|
||||
CodeBenchmarkFreeResourse(ofs, config_->module_name(), inputs_num);
|
||||
ofs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "micro/coder/generator/generator.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class TrainGenerator : public Generator {
|
||||
public:
|
||||
explicit TrainGenerator(std::unique_ptr<CoderContext> ctx) : Generator(std::move(ctx)) {}
|
||||
~TrainGenerator() override = default;
|
||||
|
||||
private:
|
||||
int CodeNetHFile() override;
|
||||
int CodeNetCFile() override;
|
||||
|
||||
int CodeBenchmarkFile() override;
|
||||
|
||||
void CodeGradientFunc(std::ofstream &ofs) const;
|
||||
};
|
||||
} // namespace mindspore::lite::micro
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
|
|
@ -28,8 +28,8 @@
|
|||
#include "securec/include/securec.h"
|
||||
#include "coder/opcoders/op_coder_register.h"
|
||||
#include "coder/log.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class CoderContext;
|
||||
constexpr int kPrecision = 19;
|
||||
|
||||
#define CODE_PARALLEL_FUNC(func) code << "ParallelLaunch(THREAD_POOL_DEFAULT, " << func << ", &args, thread_num);\n"
|
||||
|
|
|
@ -61,7 +61,7 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build() {
|
|||
}
|
||||
op_coder->set_input_tensor_indices(input_indices_);
|
||||
op_coder->set_output_tensor_indices(output_indices_);
|
||||
int thread_num = this->mode_ == CodeMode::Code_Android ? kMAX_THREAD_NUM_SUPPORT : 1;
|
||||
int thread_num = this->mode_ == CodeMode::Code_Inference ? kMAX_THREAD_NUM_SUPPORT : 1;
|
||||
op_coder->set_thread_num(thread_num);
|
||||
parameter->thread_num_ = thread_num;
|
||||
op_coder->set_parameter(parameter);
|
||||
|
|
|
@ -16,13 +16,14 @@
|
|||
|
||||
#include "coder/session.h"
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "coder/allocator/allocator.h"
|
||||
#include "coder/context.h"
|
||||
#include "coder/train.h"
|
||||
#include "coder/allocator/allocator.h"
|
||||
#include "coder/generator/generator.h"
|
||||
#include "coder/generator/inference/inference_generator.h"
|
||||
#include "coder/generator/train/train_generator.h"
|
||||
#include "coder/opcoders/op_coder_builder.h"
|
||||
#include "coder/utils/coder_utils.h"
|
||||
#include "coder/log.h"
|
||||
|
@ -89,6 +90,9 @@ void CoderSession::EndCode() {
|
|||
blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_);
|
||||
context_->set_code_blocks(blocks);
|
||||
}
|
||||
if (config->code_mode() == Code_Train) {
|
||||
Train::TransformGraphForTrain(context_.get(), op_coders_);
|
||||
}
|
||||
}
|
||||
|
||||
int CoderSession::Run() {
|
||||
|
@ -123,10 +127,14 @@ int CoderSession::GenerateCode() {
|
|||
CodeMode code_mode = config->code_mode();
|
||||
switch (code_mode) {
|
||||
case Code_Normal:
|
||||
case Code_Android:
|
||||
MS_LOG(INFO) << "generate code for Android";
|
||||
case Code_Inference:
|
||||
MS_LOG(INFO) << "generate code for Inference";
|
||||
generator = std::make_shared<InferenceGenerator>(std::move(context_));
|
||||
break;
|
||||
case Code_Train:
|
||||
MS_LOG(INFO) << "generate code for Inference";
|
||||
generator = std::make_shared<TrainGenerator>(std::move(context_));
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "unsupported generator code mode, " << code_mode;
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/train.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
|
||||
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
|
||||
std::set<OperatorCoder *> subgraph;
|
||||
std::queue<OperatorCoder *> to_visit;
|
||||
to_visit.push(edge);
|
||||
while (!to_visit.empty()) {
|
||||
size_t size = to_visit.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
OperatorCoder *curr = to_visit.front();
|
||||
to_visit.pop();
|
||||
if (subgraph.find(curr) != subgraph.end()) {
|
||||
continue;
|
||||
}
|
||||
subgraph.insert(curr);
|
||||
for (const auto &op : curr->input_ops()) {
|
||||
to_visit.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto item = subgraph.find(edge);
|
||||
if (item == subgraph.end()) {
|
||||
MS_LOG(ERROR) << "failed to find the edge in the subgraph";
|
||||
return subgraph;
|
||||
}
|
||||
// erase edge operator coder from subgraph
|
||||
subgraph.erase(item);
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
|
||||
const std::set<schema::PrimitiveType> loss_types = {schema::PrimitiveType_SoftmaxCrossEntropy,
|
||||
schema::PrimitiveType_SparseSoftmaxCrossEntropy,
|
||||
schema::PrimitiveType_BinaryCrossEntropy,
|
||||
schema::PrimitiveType_SmoothL1Loss,
|
||||
schema::PrimitiveType_SmoothL1LossGrad,
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
|
||||
OperatorCoder *loss_op = nullptr;
|
||||
for (const auto &opcoder : op_coders) {
|
||||
auto primitive_type = static_cast<schema::PrimitiveType>(opcoder->primitive()->Type());
|
||||
auto item = loss_types.find(primitive_type);
|
||||
if (item != loss_types.end()) {
|
||||
loss_op = opcoder.get();
|
||||
break;
|
||||
}
|
||||
}
|
||||
MS_CHECK_PTR(loss_op);
|
||||
size_t op_num = op_coders.size();
|
||||
std::vector<std::string> code_blocks = context->code_blocks();
|
||||
if (op_num != code_blocks.size()) {
|
||||
MS_LOG(INFO) << "the number of code blocks and op coders is not equal";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
|
||||
std::vector<std::string> inferences_blocks;
|
||||
std::vector<std::string> train_blocks;
|
||||
for (size_t i = 0; i < op_num; ++i) {
|
||||
auto &opcoder = op_coders.at(i);
|
||||
std::string block = code_blocks.at(i);
|
||||
if (inference_ops.find(opcoder.get()) != inference_ops.end()) {
|
||||
inferences_blocks.push_back(block);
|
||||
}
|
||||
train_blocks.push_back(block);
|
||||
}
|
||||
context->set_inference_blocks(inferences_blocks);
|
||||
context->set_train_blocks(train_blocks);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite::micro
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
|
||||
#define MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "coder/context.h"
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro {
|
||||
class Train {
|
||||
public:
|
||||
static int TransformGraphForTrain(CoderContext *context,
|
||||
const std::vector<std::unique_ptr<OperatorCoder>> &op_coders);
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
|
|
@ -142,32 +142,4 @@ std::vector<std::string> SplitString(std::string str, const std::string &pattern
|
|||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
|
||||
std::set<OperatorCoder *> subgraph;
|
||||
std::queue<OperatorCoder *> to_visit;
|
||||
to_visit.push(edge);
|
||||
while (!to_visit.empty()) {
|
||||
size_t size = to_visit.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
OperatorCoder *curr = to_visit.front();
|
||||
to_visit.pop();
|
||||
if (subgraph.find(curr) != subgraph.end()) {
|
||||
continue;
|
||||
}
|
||||
subgraph.insert(curr);
|
||||
for (const auto &op : curr->input_ops()) {
|
||||
to_visit.push(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto item = subgraph.find(edge);
|
||||
if (item == subgraph.end()) {
|
||||
MS_LOG(ERROR) << "failed to find the edge in the subgraph";
|
||||
return subgraph;
|
||||
}
|
||||
// erase edge operator coder from subgraph
|
||||
subgraph.erase(item);
|
||||
return subgraph;
|
||||
}
|
||||
} // namespace mindspore::lite::micro
|
||||
|
|
|
@ -35,8 +35,6 @@ std::vector<std::string> AddDumpDataInfo(const std::vector<std::string> &blocks,
|
|||
|
||||
void PrintTensorData(const lite::Tensor *tensor, std::ofstream &ofs);
|
||||
|
||||
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge);
|
||||
|
||||
} // namespace mindspore::lite::micro
|
||||
|
||||
#endif // MINDSPORE_LITE_MICRO_CODER_UTILS_CODER_UTILS_H_
|
||||
|
|
Loading…
Reference in New Issue