!13396 remove is_weight_file flag

From: @wangchengyuan
Reviewed-by: @zhanghaibo5, @hangangqiang
Signed-off-by: @zhanghaibo5
commit a39adfd010
mindspore-ci-bot, 2021-03-16 19:56:32 +08:00, committed by Gitee
7 changed files with 11 additions and 19 deletions
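In short: the isWeightFile switch is removed, and weight-binary emission is now decided by the build target alone; every target except ARM32M writes the <moduleName>.net file. A runnable sketch of the decision change (kX86 and the helper names below are hypothetical, only kARM32M appears in this diff):

```cpp
#include <iostream>

enum Target { kTargetUnknown, kARM32M, kX86 };

// Hypothetical pre-change decision: a user flag controlled weight-file output.
bool emit_weight_file_before(bool is_weight_file_flag) { return is_weight_file_flag; }

// Post-change decision: every target except ARM32M gets a weight binary.
bool emit_weight_file_after(Target target) { return target != kARM32M; }

int main() {
  std::cout << emit_weight_file_after(kX86) << "\n";     // 1: x86 writes <moduleName>.net
  std::cout << emit_weight_file_after(kARM32M) << "\n";  // 0: ARM32M keeps weights embedded
  return 0;
}
```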

View File

@@ -14,7 +14,7 @@ include(${TOP_DIR}/cmake/dependency_utils.cmake)
include(${TOP_DIR}/cmake/dependency_securec.cmake)
if(NOT PLATFORM_ARM64 AND NOT PLATFORM_ARM32)
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
### flatbuffer
include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake)
file(GLOB FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/../schema/*.fbs)
ms_build_flatbuffers_lite(FBS_FILES

View File

@@ -34,7 +34,6 @@ namespace mindspore::lite::micro {
class CoderFlags : public virtual FlagParser {
public:
CoderFlags() {
- AddFlag(&CoderFlags::is_weight_file_, "isWeightFile", "whether generating weight binary file, true| false", false);
AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
@@ -48,7 +47,6 @@ class CoderFlags : public virtual FlagParser {
std::string model_path_;
bool support_parallel_{false};
- bool is_weight_file_{false};
std::string code_module_name_;
std::string code_path_;
std::string code_mode_;
@@ -93,11 +91,6 @@ int Coder::Init(const CoderFlags &flags) const {
Configurator *config = Configurator::GetInstance();
std::vector<std::function<bool()>> parsers;
- parsers.emplace_back([flags, config]() -> bool {
- config->set_is_weight_file(flags.is_weight_file_);
- return true;
- });
parsers.emplace_back([&flags, config]() -> bool {
auto target_item = kTargetMap.find(flags.target_);
MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
@@ -113,6 +106,10 @@ int Coder::Init(const CoderFlags &flags) const {
});
parsers.emplace_back([&flags, config]() -> bool {
+ if (flags.support_parallel_ == true && config->target() == kARM32M) {
+ MS_LOG(ERROR) << "arm32M cannot support parallel.";
+ return false;
+ }
config->set_support_parallel(flags.support_parallel_);
return true;
});
@@ -175,7 +172,6 @@ int Coder::Init(const CoderFlags &flags) const {
print_parameter("codePath", config->code_path());
print_parameter("codeMode", config->code_mode());
print_parameter("codeModuleName", config->module_name());
print_parameter("isWeightFile", config->is_weight_file());
print_parameter("debugMode", config->debug_mode());
return RET_OK;
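The Init path above validates flags through a vector of std::function<bool()> parsers, each copying one flag into the Configurator singleton and returning false to abort on invalid input. A self-contained sketch of that pattern, with hypothetical Flags/Config structs standing in for CoderFlags and Configurator:

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for CoderFlags and Configurator.
struct Flags {
  std::string target = "ARM32M";
  bool support_parallel = true;
};
struct Config {
  std::string target;
  bool support_parallel = false;
};

int main() {
  Flags flags;
  Config config;
  // Each parser validates one flag, copies it into the config,
  // and returns false to abort Init on bad input.
  std::vector<std::function<bool()>> parsers;
  parsers.emplace_back([&flags, &config]() -> bool {
    config.target = flags.target;
    return true;
  });
  parsers.emplace_back([&flags, &config]() -> bool {
    if (flags.support_parallel && config.target == "ARM32M") {
      std::cerr << "arm32M cannot support parallel.\n";  // mirrors the MS_LOG(ERROR) above
      return false;
    }
    config.support_parallel = flags.support_parallel;
    return true;
  });
  for (const auto &parse : parsers) {
    if (!parse()) return 1;  // mirrors Coder::Init returning an error code
  }
  return 0;
}
```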

View File

@@ -73,9 +73,6 @@ class Configurator {
void set_debug_mode(bool debug) { debug_mode_ = debug; }
bool debug_mode() const { return debug_mode_; }
- void set_is_weight_file(bool flag) { is_weight_file_ = flag; }
- bool is_weight_file() const { return is_weight_file_; }
void set_support_parallel(bool parallel) { support_parallel_ = parallel; }
bool support_parallel() const { return support_parallel_; }
@@ -87,7 +84,6 @@ class Configurator {
std::string code_path_;
Target target_{kTargetUnknown};
CodeMode code_mode_{Code_Unknown};
- bool is_weight_file_{false};
bool support_parallel_{false};
bool debug_mode_{false};
};
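For context, Configurator is reached everywhere through Configurator::GetInstance(). A minimal sketch of the class after this commit, under the assumption of a Meyers-style singleton (the constructor's visibility, the set_target accessor, and any enum values beyond those in this diff are not shown here and are assumed):

```cpp
#include <string>

// Only the enum values visible in this diff; the rest are elided.
enum Target { kTargetUnknown, kARM32M };
enum CodeMode { Code_Unknown };

class Configurator {
 public:
  static Configurator *GetInstance() {
    static Configurator instance;  // assumption: Meyers-style singleton
    return &instance;
  }
  void set_target(Target target) { target_ = target; }  // assumed setter
  Target target() const { return target_; }
  void set_debug_mode(bool debug) { debug_mode_ = debug; }
  bool debug_mode() const { return debug_mode_; }
  void set_support_parallel(bool parallel) { support_parallel_ = parallel; }
  bool support_parallel() const { return support_parallel_; }

 private:
  Configurator() = default;
  std::string code_path_;
  Target target_{kTargetUnknown};
  CodeMode code_mode_{Code_Unknown};
  bool support_parallel_{false};
  bool debug_mode_{false};  // is_weight_file_ is gone after this commit
};
```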

View File

@@ -137,7 +137,7 @@ int Generator::CodeWeightFile() {
cofs << "#include \"" << net_weight_hfile_ << "\"\n\n";
cofs << "unsigned char * " << ctx_->buffer_name() << " = 0 ; \n";
- if (config_->is_weight_file()) {
+ if (config_->target() != kARM32M) {
std::string net_file = net_src_file_path_ + config_->module_name() + ".net";
SaveDataToNet(ctx_->saved_weights(), net_file);
CodeModelParamsForNet(hofs, cofs, ctx_);
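With the flag gone, CodeWeightFile writes <moduleName>.net next to the generated sources for every non-ARM32M target. For orientation, the two cofs << statements above produce the top of the generated weight source file; with hypothetical values net_weight.h for net_weight_hfile_ and g_Buffer for ctx_->buffer_name(), the emitted file begins:

```cpp
#include "net_weight.h"   // hypothetical value of net_weight_hfile_

unsigned char * g_Buffer = 0 ;   // hypothetical value of ctx_->buffer_name()
```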

View File

@@ -40,7 +40,7 @@ int InferenceGenerator::CodeNetHFile() {
if (config_->support_parallel()) {
CodeSetGlobalThreadPoolState(ofs, config_->module_name());
}
- if (config_->is_weight_file()) {
+ if (config_->target() != kARM32M) {
CodeInitWeightState(ofs, config_->module_name());
}
CodeManageResourceState(ofs, config_->module_name());
@@ -82,7 +82,7 @@ int InferenceGenerator::CodeBenchmarkFile() {
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
CodeBenchmarkSetBuffer(ofs, config_->module_name());
- if (config_->is_weight_file()) {
+ if (config_->target() != kARM32M) {
CodeBenchmarkInitWeight(ofs, config_->module_name());
}
if (config_->support_parallel()) {
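Reconstructed from the Code* helpers above, the generated benchmark sets inputs and buffer unconditionally and initializes weights whenever the target is not ARM32M. A hedged, runnable sketch of the emitted call sequence; the function names are hypothetical stubs (the real ones are derived from config_->module_name(), which this diff does not show):

```cpp
#include <cstdio>

// Hypothetical stubs for what the generated benchmark calls.
void net_SetInputs() { std::puts("set inputs"); }     // emitted by CodeBenchmarkSetInputs
void net_SetBuffer() { std::puts("set buffer"); }     // emitted by CodeBenchmarkSetBuffer
void net_InitWeight() { std::puts("load net.net"); }  // emitted by CodeBenchmarkInitWeight

int main() {
  net_SetInputs();
  net_SetBuffer();
  net_InitWeight();  // emitted only when target != kARM32M (was: --isWeightFile)
  return 0;
}
```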

View File

@@ -45,7 +45,7 @@ int TrainGenerator::CodeNetHFile() {
ofs << "#include \"microtensor.h\"\n\n";
CodeTrainParams(ofs);
CodeInputAndOutputState(ofs, config_->module_name());
- if (config_->is_weight_file()) {
+ if (config_->target() != kARM32M) {
CodeInitWeightState(ofs, config_->module_name());
}
CodeManageResourceState(ofs, config_->module_name());
@@ -84,7 +84,7 @@ int TrainGenerator::CodeBenchmarkFile() {
CodeBenchmarkWarmup(ofs, config_->module_name());
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
CodeBenchmarkSetBuffer(ofs, config_->module_name());
- if (config_->is_weight_file()) {
+ if (config_->target() != kARM32M) {
CodeBenchmarkInitWeight(ofs, config_->module_name());
}
CodeBenchmarkInference(ofs, config_->module_name());

View File

@@ -42,10 +42,10 @@ class OperatorCoder {
node_(node),
node_index_(node_index) {
allocator_ = MemoryAllocator::GetInstance();
// vectors checked not empty in OpCoderBuilder::build
input_tensor_ = input_tensors_.at(kInputIndex);
output_tensor_ = output_tensors_.at(kOutputIndex);
}
std::string name() const { return node_->name_; }
void set_input_tensor_indices(const std::vector<uint32_t> &input_indices);
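Per the comment above, OpCoderBuilder::build has already verified that the tensor vectors are non-empty, and std::vector::at adds a second, throwing line of defense. A tiny self-contained illustration of why .at() is preferred over operator[] here (kInputIndex is assumed to be 0 in this sketch):

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

int main() {
  std::vector<int> input_tensors;  // imagine the builder check was skipped
  try {
    int first = input_tensors.at(0);  // kInputIndex assumed to be 0 here
    std::cout << first << "\n";
  } catch (const std::out_of_range &e) {
    // .at() fails loudly instead of reading out of bounds like operator[]
    std::cerr << "empty tensor vector: " << e.what() << "\n";
  }
  return 0;
}
```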