!12182 [MSLITE]feature, micro support train

From: @yangjie159
Reviewed-by: @wangchengyuan,@HilbertDavid
Signed-off-by: @wangchengyuan
This commit is contained in:
mindspore-ci-bot 2021-02-18 17:03:29 +08:00 committed by Gitee
commit 9d2e07ae24
16 changed files with 521 additions and 47 deletions

View File

@ -4,6 +4,7 @@ set(CODER_SRC
${MICRO_DIR}/coder/context.cc
${MICRO_DIR}/coder/graph.cc
${MICRO_DIR}/coder/session.cc
${MICRO_DIR}/coder/train.cc
)
set(CODER_ALLOCATOR_SRC
@ -14,10 +15,12 @@ set(CODER_ALLOCATOR_SRC
set(CODER_GENERATOR_SRC
${MICRO_DIR}/coder/generator/generator.cc
${MICRO_DIR}/coder/generator/inference/inference_generator.cc
${MICRO_DIR}/coder/generator/train/train_generator.cc
${MICRO_DIR}/coder/generator/component/benchmark_component.cc
${MICRO_DIR}/coder/generator/component/common_component.cc
${MICRO_DIR}/coder/generator/component/weight_component.cc
${MICRO_DIR}/coder/generator/component/cmake_component.cc
${MICRO_DIR}/coder/generator/component/train_component.cc
)
set(CODER_OPCODERS_SRC

View File

@ -39,7 +39,7 @@ class CoderFlags : public virtual FlagParser {
AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
AddFlag(&CoderFlags::target_, "target", "generateed code target, x86| ARM32M| ARM32A| ARM64", "x86");
AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Android ", "Normal");
AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Inference | Train", "Normal");
AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump perlayer's time cost and tensor, true | false", false);
}
@ -87,7 +87,8 @@ int Coder::Run(const std::string &model_path) {
int Coder::Init(const CoderFlags &flags) const {
static const std::map<std::string, Target> kTargetMap = {
{"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
static const std::map<std::string, CodeMode> kCodeModeMap = {{"Normal", Code_Normal}, {"Android", Code_Android}};
static const std::map<std::string, CodeMode> kCodeModeMap = {
{"Normal", Code_Normal}, {"Inference", Code_Inference}, {"Train", Code_Train}};
Configurator *config = Configurator::GetInstance();

View File

@ -21,7 +21,7 @@
namespace mindspore::lite::micro {
enum Target { kX86 = 0, kARM32M = 1, kARM32A = 2, kARM64 = 3, kAllTargets = 4, kTargetUnknown = 99 };
enum CodeMode { Code_Normal = 0, Code_Android = 1, Code_Unknown = 99 };
enum CodeMode { Code_Normal = 0, Code_Inference = 1, Code_Train = 2, Code_Unknown = 99 };
class Configurator {
public:

View File

@ -0,0 +1,180 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/generator/component/train_component.h"
#include <string>
#include "coder/utils/type_cast.h"
namespace mindspore::lite::micro {
void CodeTrainParams(std::ofstream &ofs) {
ofs << "struct TrainParameter {\n"
" float beta1_;\n"
" float beta2_;\n"
" float epsilon_;\n"
"};\n"
"\n"
"enum EarlyStopType {\n"
" Diff = 0,\n"
" WeigthDiff = 1,\n"
" Abs = 2,\n"
"};\n"
"\n"
"struct EarlyStop {\n"
" enum EarlyStopType type;\n"
" float tolerate;\n"
"};\n\n";
}
void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name) {
ofs << "/**\n"
" *\n"
" * @param size, return the number of features\n"
" * @return, the address of features\n"
" */\n"
<< "FeatureParam *" << module_name << "_GetFeatures(int *size);\n\n";
ofs << "/**\n"
" *\n"
" * @param features, the address of features\n"
" * @param size, the number of features\n"
" * @return, status\n"
" */\n"
<< "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size);\n\n";
}
void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
const std::unique_ptr<CoderContext> &ctx) {
size_t features_num = 0;
ofs << "static FeatureParam feature_params[] = {\n";
for (const auto &item : ctx->saved_weights()) {
std::string addr = item.first;
Tensor *tensor = item.second;
if (tensor->tensor_name().empty()) {
MS_LOG(ERROR) << "exist empty feature";
continue;
}
ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", "
<< EnumMicroTensorDataType(tensor->data_type()) << "}, \n";
features_num++;
}
ofs << "};\n";
ofs << "FeatureParam *" << module_name << "_GetFeatures(int *size) {\n"
<< " *size = " << features_num << ";\n"
<< " return feature_params;\n"
"}\n\n";
ofs << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size) {\n"
<< " for (int i = 0; i < size; ++i) {\n"
" FeatureParam *src = features + i;\n"
" FeatureParam dst;\n"
" // find the dst feature\n"
" bool is_find = false;\n"
<< " for (int j = 0; j < " << features_num << "; ++j) {\n"
<< " if (strcmp(src->name, feature_params[j].name) == 0) {\n"
" dst = feature_params[j];\n"
" is_find = true;\n"
" break;\n"
" }\n"
" }\n"
" if (!is_find) {\n"
" MICRO_ERROR(\"invalid feature param: %s\", src->name);\n"
" return RET_ERROR;\n"
" }\n"
" if (src->elenums != dst.elenums) {\n"
" MICRO_ERROR(\"feature %s elenums is mismatch, src: %lu, dst: %lu\", src->name, src->elenums, "
"dst.elenums);\n"
" return RET_ERROR;\n"
" }\n"
" memcpy(dst.data, src->data, src->elenums * sizeof(float));\n"
" }\n"
" MICRO_INFO(\"update features map success\");\n"
" return RET_OK;\n"
"}\n\n";
}
void CodeTrainState(std::ofstream &ofs, const std::string &module_name) {
ofs << "/**\n"
" * Train Function\n"
" * @param epoch, the train epoch\n"
" * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n"
" * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n"
" * parameters to improve the accuracy\n"
" * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n"
" * @return status\n"
" */\n"
<< "int " << module_name
<< "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, "
"const struct EarlyStop *early_stop);\n\n";
}
void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx) {
std::vector<Tensor *> inputs = ctx->graph_inputs();
size_t inputs_num = inputs.size();
auto inputs_tostring = [&]() {
std::string result;
result += "{";
for (size_t i = 0; i < inputs.size(); ++i) {
result += ctx->input_name() + std::to_string(i) + ", ";
}
result += "}";
return result;
};
auto wrap = [](int i) { return "[" + std::to_string(i) + "]"; };
auto offset_inputs = [&]() {
std::string src = "origin_inputs";
std::string dst = "input_ptr";
std::string result;
for (size_t i = 0; i < inputs.size(); ++i) {
result += dst + wrap(i) += " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n";
}
return result;
};
auto varify_inputs = [&]() {
std::string result;
for (size_t i = 0; i < inputs.size(); ++i) {
result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL";
i < inputs.size() - 1 ? result += " || " : result += "";
}
return result;
};
ofs << "int " << module_name
<< "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter "
"*parameter, const struct EarlyStop *early_stop) {\n"
" if (iterations <= 0 || epoch <= 0) {\n"
" MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n"
" return RET_ERROR;\n"
" }\n"
" MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n"
<< " const void *origin_input[] = " << inputs_tostring() << ";\n";
ofs << " if (" << varify_inputs() << ") {\n"
<< " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n"
" return RET_ERROR;\n"
" }\n";
ofs << " for (int i = 0; i < epoch; ++i) {\n"
<< " const void *input_ptr[" << inputs_num << "];\n"
<< " float loss = 0;\n"
<< " for (int j = 0; j < iterations; ++j) {\n"
<< " " << offset_inputs() << "\n"
<< " " << module_name << "_SetInputs(input_ptr, " << inputs_num << ");\n"
<< " " << module_name << "_Inference();\n"
<< " loss = " << module_name << "_ComputeLossAndGradient();\n"
<< " }\n"
" }\n"
" return RET_OK;\n"
"};\n\n";
}
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <fstream>
#include "src/tensor.h"
#include "coder/context.h"
namespace mindspore::lite::micro {
void CodeTrainParams(std::ofstream &ofs);
void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name);
void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name,
const std::unique_ptr<CoderContext> &ctx);
void CodeTrainState(std::ofstream &ofs, const std::string &module_name);
void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_

View File

@ -59,15 +59,13 @@ Generator::Generator(std::unique_ptr<CoderContext> ctx) {
Generator::~Generator() { (void)umask(origin_umask_); }
void Generator::CodeNetRunFunc(std::ofstream &ofs) {
// generate net predict code
// generate net inference code
ofs << "void " << config_->module_name() << "_Inference() {\n";
if (config_->code_mode() == CodeMode::Code_Android) {
if (config_->code_mode() == CodeMode::Code_Inference) {
ofs << "int thread_num = GetCurrentThreadNum(THREAD_POOL_DEFAULT);\n";
}
for (const auto &codeBlock : ctx_->code_blocks()) {
ofs << "\t{\n";
ofs << codeBlock;
ofs << "\t}\n";
for (const auto &block : ctx_->code_blocks()) {
ofs << "\t{\n" << block << "\t}\n";
}
ofs << "}\n";
}

View File

@ -28,7 +28,7 @@ int InferenceGenerator::CodeNetHFile() {
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_include_file;
ofs << g_hwLicense;
if (config_->code_mode() == CodeMode::Code_Android) {
if (config_->code_mode() == CodeMode::Code_Inference) {
ofs << "#include \"src/runtime/thread_pool.h\"\n";
}
ofs << "#include \"microtensor.h\"\n\n";
@ -78,7 +78,7 @@ int InferenceGenerator::CodeBenchmarkFile() {
if (config_->is_weight_file()) {
CodeBenchmarkInitWeight(ofs, config_->module_name());
}
if (config_->code_mode() == CodeMode::Code_Android) {
if (config_->code_mode() == CodeMode::Code_Inference) {
CodeBenchmarkConfigThread(ofs);
}
CodeBenchmarkInference(ofs, config_->module_name());

View File

@ -0,0 +1,108 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/generator/train/train_generator.h"
#include <vector>
#include <string>
#include "coder/generator/component/common_component.h"
#include "coder/generator/component/benchmark_component.h"
#include "coder/generator/component/train_component.h"
#include "coder/generator/component/const_blocks/license.h"
namespace mindspore::lite::micro {
void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const {
ofs << "float " << config_->module_name() << "_ComputeLossAndGradient() {\n";
ofs << " float loss = 0;\n";
for (const auto &block : ctx_->train_blocks()) {
ofs << " {\n" << block << " }\n";
}
ofs << " return loss;\n";
ofs << "}\n";
}
int TrainGenerator::CodeNetHFile() {
std::string net_include_file = net_inc_file_path_ + net_inc_hfile_;
std::ofstream ofs(net_include_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_include_file;
ofs << g_hwLicense;
if (config_->code_mode() == CodeMode::Code_Inference) {
ofs << "#include \"src/runtime/thread_pool.h\"\n";
}
ofs << "#include \"microtensor.h\"\n\n";
CodeTrainParams(ofs);
CodeInputAndOutputState(ofs, config_->module_name());
if (is_get_quant_args_) {
CodeGraphQuantArgsState(ofs, config_->module_name());
}
if (config_->is_weight_file()) {
CodeInitWeightState(ofs, config_->module_name());
}
CodeManageResourceState(ofs, config_->module_name());
CodeInferenceState(ofs, config_->module_name());
CodeFeaturesState(ofs, config_->module_name());
CodeTrainState(ofs, config_->module_name());
return RET_OK;
}
int TrainGenerator::CodeNetCFile() {
std::string net_impl_file = net_src_file_path_ + net_src_cfile_;
std::ofstream ofs(net_impl_file);
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
MS_LOG(INFO) << "write " << net_impl_file;
CodeSourceFileInclude(ofs, net_weight_hfile_, net_inc_hfile_);
CodeInputAndOutputImplement(ofs, config_->module_name(), ctx_);
CodeInitResourceImplement(ofs, config_->module_name(), ctx_);
CodeFreeResourceImplement(ofs, config_->module_name(), ctx_);
CodeFeaturesImplement(ofs, config_->module_name(), ctx_);
if (is_get_quant_args_) {
CodeGraphQuantArgsImplement(ofs, config_->module_name(), ctx_);
}
CodeNetRunFunc(ofs);
CodeGradientFunc(ofs);
CodeTrainImplement(ofs, config_->module_name(), ctx_);
ofs.close();
return RET_OK;
}
int TrainGenerator::CodeBenchmarkFile() {
std::string net_main_impl_file = net_main_file_path_ + net_main_cfile_;
std::ofstream ofs(net_main_impl_file);
MS_LOG(INFO) << "write " << net_main_impl_file;
MS_CHECK_TRUE(!ofs.bad(), "filed to open file");
std::vector<Tensor *> inputs = ctx_->graph_inputs();
size_t inputs_num = inputs.size();
CodeBenchmarkHeader(ofs, net_inc_hfile_);
CodeBenchmarkUsage(ofs);
CodeBenchmarkWarmup(ofs, config_->module_name());
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
CodeBenchmarkSetBuffer(ofs, config_->module_name());
if (config_->is_weight_file()) {
CodeBenchmarkInitWeight(ofs, config_->module_name());
}
if (config_->code_mode() == CodeMode::Code_Inference) {
CodeBenchmarkConfigThread(ofs);
}
CodeBenchmarkInference(ofs, config_->module_name());
CodeBenchmarkPrintOutputs(ofs, config_->module_name());
CodeBenchmarkFreeResourse(ofs, config_->module_name(), inputs_num);
ofs.close();
return RET_OK;
}
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
#define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_
#include <utility>
#include <memory>
#include "micro/coder/generator/generator.h"
namespace mindspore::lite::micro {
class TrainGenerator : public Generator {
public:
explicit TrainGenerator(std::unique_ptr<CoderContext> ctx) : Generator(std::move(ctx)) {}
~TrainGenerator() override = default;
private:
int CodeNetHFile() override;
int CodeNetCFile() override;
int CodeBenchmarkFile() override;
void CodeGradientFunc(std::ofstream &ofs) const;
};
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_

View File

@ -28,8 +28,8 @@
#include "securec/include/securec.h"
#include "coder/opcoders/op_coder_register.h"
#include "coder/log.h"
namespace mindspore::lite::micro {
class CoderContext;
constexpr int kPrecision = 19;
#define CODE_PARALLEL_FUNC(func) code << "ParallelLaunch(THREAD_POOL_DEFAULT, " << func << ", &args, thread_num);\n"

View File

@ -61,7 +61,7 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build() {
}
op_coder->set_input_tensor_indices(input_indices_);
op_coder->set_output_tensor_indices(output_indices_);
int thread_num = this->mode_ == CodeMode::Code_Android ? kMAX_THREAD_NUM_SUPPORT : 1;
int thread_num = this->mode_ == CodeMode::Code_Inference ? kMAX_THREAD_NUM_SUPPORT : 1;
op_coder->set_thread_num(thread_num);
parameter->thread_num_ = thread_num;
op_coder->set_parameter(parameter);

View File

@ -16,13 +16,14 @@
#include "coder/session.h"
#include <set>
#include <queue>
#include <vector>
#include <utility>
#include "coder/allocator/allocator.h"
#include "coder/context.h"
#include "coder/train.h"
#include "coder/allocator/allocator.h"
#include "coder/generator/generator.h"
#include "coder/generator/inference/inference_generator.h"
#include "coder/generator/train/train_generator.h"
#include "coder/opcoders/op_coder_builder.h"
#include "coder/utils/coder_utils.h"
#include "coder/log.h"
@ -89,6 +90,9 @@ void CoderSession::EndCode() {
blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_);
context_->set_code_blocks(blocks);
}
if (config->code_mode() == Code_Train) {
Train::TransformGraphForTrain(context_.get(), op_coders_);
}
}
int CoderSession::Run() {
@ -123,10 +127,14 @@ int CoderSession::GenerateCode() {
CodeMode code_mode = config->code_mode();
switch (code_mode) {
case Code_Normal:
case Code_Android:
MS_LOG(INFO) << "generate code for Android";
case Code_Inference:
MS_LOG(INFO) << "generate code for Inference";
generator = std::make_shared<InferenceGenerator>(std::move(context_));
break;
case Code_Train:
MS_LOG(INFO) << "generate code for Inference";
generator = std::make_shared<TrainGenerator>(std::move(context_));
break;
default:
MS_LOG(ERROR) << "unsupported generator code mode, " << code_mode;
return RET_ERROR;

View File

@ -0,0 +1,95 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "coder/train.h"
#include <memory>
#include <set>
#include <map>
#include <queue>
#include <string>
#include <vector>
namespace mindspore::lite::micro {
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
std::set<OperatorCoder *> subgraph;
std::queue<OperatorCoder *> to_visit;
to_visit.push(edge);
while (!to_visit.empty()) {
size_t size = to_visit.size();
for (size_t i = 0; i < size; ++i) {
OperatorCoder *curr = to_visit.front();
to_visit.pop();
if (subgraph.find(curr) != subgraph.end()) {
continue;
}
subgraph.insert(curr);
for (const auto &op : curr->input_ops()) {
to_visit.push(op);
}
}
}
auto item = subgraph.find(edge);
if (item == subgraph.end()) {
MS_LOG(ERROR) << "failed to find the edge in the subgraph";
return subgraph;
}
// erase edge operator coder from subgraph
subgraph.erase(item);
return subgraph;
}
int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
const std::set<schema::PrimitiveType> loss_types = {schema::PrimitiveType_SoftmaxCrossEntropy,
schema::PrimitiveType_SparseSoftmaxCrossEntropy,
schema::PrimitiveType_BinaryCrossEntropy,
schema::PrimitiveType_SmoothL1Loss,
schema::PrimitiveType_SmoothL1LossGrad,
schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
OperatorCoder *loss_op = nullptr;
for (const auto &opcoder : op_coders) {
auto primitive_type = static_cast<schema::PrimitiveType>(opcoder->primitive()->Type());
auto item = loss_types.find(primitive_type);
if (item != loss_types.end()) {
loss_op = opcoder.get();
break;
}
}
MS_CHECK_PTR(loss_op);
size_t op_num = op_coders.size();
std::vector<std::string> code_blocks = context->code_blocks();
if (op_num != code_blocks.size()) {
MS_LOG(INFO) << "the number of code blocks and op coders is not equal";
return RET_ERROR;
}
std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
std::vector<std::string> inferences_blocks;
std::vector<std::string> train_blocks;
for (size_t i = 0; i < op_num; ++i) {
auto &opcoder = op_coders.at(i);
std::string block = code_blocks.at(i);
if (inference_ops.find(opcoder.get()) != inference_ops.end()) {
inferences_blocks.push_back(block);
}
train_blocks.push_back(block);
}
context->set_inference_blocks(inferences_blocks);
context->set_train_blocks(train_blocks);
return RET_OK;
}
} // namespace mindspore::lite::micro

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
#define MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_
#include <memory>
#include <vector>
#include "coder/context.h"
#include "coder/opcoders/op_coder.h"
namespace mindspore::lite::micro {
class Train {
public:
static int TransformGraphForTrain(CoderContext *context,
const std::vector<std::unique_ptr<OperatorCoder>> &op_coders);
};
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_

View File

@ -142,32 +142,4 @@ std::vector<std::string> SplitString(std::string str, const std::string &pattern
}
return results;
}
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
std::set<OperatorCoder *> subgraph;
std::queue<OperatorCoder *> to_visit;
to_visit.push(edge);
while (!to_visit.empty()) {
size_t size = to_visit.size();
for (size_t i = 0; i < size; ++i) {
OperatorCoder *curr = to_visit.front();
to_visit.pop();
if (subgraph.find(curr) != subgraph.end()) {
continue;
}
subgraph.insert(curr);
for (const auto &op : curr->input_ops()) {
to_visit.push(op);
}
}
}
auto item = subgraph.find(edge);
if (item == subgraph.end()) {
MS_LOG(ERROR) << "failed to find the edge in the subgraph";
return subgraph;
}
// erase edge operator coder from subgraph
subgraph.erase(item);
return subgraph;
}
} // namespace mindspore::lite::micro

View File

@ -35,8 +35,6 @@ std::vector<std::string> AddDumpDataInfo(const std::vector<std::string> &blocks,
void PrintTensorData(const lite::Tensor *tensor, std::ofstream &ofs);
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge);
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_UTILS_CODER_UTILS_H_