forked from mindspore-Ecosystem/mindspore
!13396 remove is_weight_file flag
From: @wangchengyuan Reviewed-by: @zhanghaibo5,@hangangqiang Signed-off-by: @zhanghaibo5
This commit is contained in:
commit
a39adfd010
|
@ -14,7 +14,7 @@ include(${TOP_DIR}/cmake/dependency_utils.cmake)
|
|||
include(${TOP_DIR}/cmake/dependency_securec.cmake)
|
||||
if(NOT PLATFORM_ARM64 AND NOT PLATFORM_ARM32)
|
||||
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
|
||||
|
||||
### flatbuffer
|
||||
include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake)
|
||||
file(GLOB FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/../schema/*.fbs)
|
||||
ms_build_flatbuffers_lite(FBS_FILES
|
||||
|
|
|
@ -34,7 +34,6 @@ namespace mindspore::lite::micro {
|
|||
class CoderFlags : public virtual FlagParser {
|
||||
public:
|
||||
CoderFlags() {
|
||||
AddFlag(&CoderFlags::is_weight_file_, "isWeightFile", "whether generating weight binary file, true| false", false);
|
||||
AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
|
||||
AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
|
||||
AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
|
||||
|
@ -48,7 +47,6 @@ class CoderFlags : public virtual FlagParser {
|
|||
|
||||
std::string model_path_;
|
||||
bool support_parallel_{false};
|
||||
bool is_weight_file_{false};
|
||||
std::string code_module_name_;
|
||||
std::string code_path_;
|
||||
std::string code_mode_;
|
||||
|
@ -93,11 +91,6 @@ int Coder::Init(const CoderFlags &flags) const {
|
|||
Configurator *config = Configurator::GetInstance();
|
||||
|
||||
std::vector<std::function<bool()>> parsers;
|
||||
parsers.emplace_back([flags, config]() -> bool {
|
||||
config->set_is_weight_file(flags.is_weight_file_);
|
||||
return true;
|
||||
});
|
||||
|
||||
parsers.emplace_back([&flags, config]() -> bool {
|
||||
auto target_item = kTargetMap.find(flags.target_);
|
||||
MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
|
||||
|
@ -113,6 +106,10 @@ int Coder::Init(const CoderFlags &flags) const {
|
|||
});
|
||||
|
||||
parsers.emplace_back([&flags, config]() -> bool {
|
||||
if (flags.support_parallel_ == true && config->target() == kARM32M) {
|
||||
MS_LOG(ERROR) << "arm32M cannot support parallel.";
|
||||
return false;
|
||||
}
|
||||
config->set_support_parallel(flags.support_parallel_);
|
||||
return true;
|
||||
});
|
||||
|
@ -175,7 +172,6 @@ int Coder::Init(const CoderFlags &flags) const {
|
|||
print_parameter("codePath", config->code_path());
|
||||
print_parameter("codeMode", config->code_mode());
|
||||
print_parameter("codeModuleName", config->module_name());
|
||||
print_parameter("isWeightFile", config->is_weight_file());
|
||||
print_parameter("debugMode", config->debug_mode());
|
||||
|
||||
return RET_OK;
|
||||
|
|
|
@ -73,9 +73,6 @@ class Configurator {
|
|||
void set_debug_mode(bool debug) { debug_mode_ = debug; }
|
||||
bool debug_mode() const { return debug_mode_; }
|
||||
|
||||
void set_is_weight_file(bool flag) { is_weight_file_ = flag; }
|
||||
bool is_weight_file() const { return is_weight_file_; }
|
||||
|
||||
void set_support_parallel(bool parallel) { support_parallel_ = parallel; }
|
||||
bool support_parallel() const { return support_parallel_; }
|
||||
|
||||
|
@ -87,7 +84,6 @@ class Configurator {
|
|||
std::string code_path_;
|
||||
Target target_{kTargetUnknown};
|
||||
CodeMode code_mode_{Code_Unknown};
|
||||
bool is_weight_file_{false};
|
||||
bool support_parallel_{false};
|
||||
bool debug_mode_{false};
|
||||
};
|
||||
|
|
|
@ -137,7 +137,7 @@ int Generator::CodeWeightFile() {
|
|||
cofs << "#include \"" << net_weight_hfile_ << "\"\n\n";
|
||||
cofs << "unsigned char * " << ctx_->buffer_name() << " = 0 ; \n";
|
||||
|
||||
if (config_->is_weight_file()) {
|
||||
if (config_->target() != kARM32M) {
|
||||
std::string net_file = net_src_file_path_ + config_->module_name() + ".net";
|
||||
SaveDataToNet(ctx_->saved_weights(), net_file);
|
||||
CodeModelParamsForNet(hofs, cofs, ctx_);
|
||||
|
|
|
@ -40,7 +40,7 @@ int InferenceGenerator::CodeNetHFile() {
|
|||
if (config_->support_parallel()) {
|
||||
CodeSetGlobalThreadPoolState(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->is_weight_file()) {
|
||||
if (config_->target() != kARM32M) {
|
||||
CodeInitWeightState(ofs, config_->module_name());
|
||||
}
|
||||
CodeManageResourceState(ofs, config_->module_name());
|
||||
|
@ -82,7 +82,7 @@ int InferenceGenerator::CodeBenchmarkFile() {
|
|||
|
||||
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
|
||||
CodeBenchmarkSetBuffer(ofs, config_->module_name());
|
||||
if (config_->is_weight_file()) {
|
||||
if (config_->target() != kARM32M) {
|
||||
CodeBenchmarkInitWeight(ofs, config_->module_name());
|
||||
}
|
||||
if (config_->support_parallel()) {
|
||||
|
|
|
@ -45,7 +45,7 @@ int TrainGenerator::CodeNetHFile() {
|
|||
ofs << "#include \"microtensor.h\"\n\n";
|
||||
CodeTrainParams(ofs);
|
||||
CodeInputAndOutputState(ofs, config_->module_name());
|
||||
if (config_->is_weight_file()) {
|
||||
if (config_->target() != kARM32M) {
|
||||
CodeInitWeightState(ofs, config_->module_name());
|
||||
}
|
||||
CodeManageResourceState(ofs, config_->module_name());
|
||||
|
@ -84,7 +84,7 @@ int TrainGenerator::CodeBenchmarkFile() {
|
|||
CodeBenchmarkWarmup(ofs, config_->module_name());
|
||||
CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_);
|
||||
CodeBenchmarkSetBuffer(ofs, config_->module_name());
|
||||
if (config_->is_weight_file()) {
|
||||
if (config_->target() != kARM32M) {
|
||||
CodeBenchmarkInitWeight(ofs, config_->module_name());
|
||||
}
|
||||
CodeBenchmarkInference(ofs, config_->module_name());
|
||||
|
|
|
@ -42,10 +42,10 @@ class OperatorCoder {
|
|||
node_(node),
|
||||
node_index_(node_index) {
|
||||
allocator_ = MemoryAllocator::GetInstance();
|
||||
// vectors checked not empty in OpCoderBuilder::build
|
||||
input_tensor_ = input_tensors_.at(kInputIndex);
|
||||
output_tensor_ = output_tensors_.at(kOutputIndex);
|
||||
}
|
||||
|
||||
std::string name() const { return node_->name_; }
|
||||
|
||||
void set_input_tensor_indices(const std::vector<uint32_t> &input_indices);
|
||||
|
|
Loading…
Reference in New Issue