diff --git a/mindspore/lite/examples/transfer_learning/src/net_runner.cc b/mindspore/lite/examples/transfer_learning/src/net_runner.cc index d400d95e8a3..3e5e87bf48e 100644 --- a/mindspore/lite/examples/transfer_learning/src/net_runner.cc +++ b/mindspore/lite/examples/transfer_learning/src/net_runner.cc @@ -16,9 +16,11 @@ #include "src/net_runner.h" #include +#include #include #include #include +#include #include #include #include "include/context.h" @@ -183,6 +185,7 @@ int NetRunner::TrainLoop() { session_->Train(); float min_loss = 1000.; float max_acc = 0.; + auto start_time = std::chrono::high_resolution_clock::now(); for (int i = 0; i < cycles_; i++) { FillInputData(ds_.train_data()); session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr); @@ -205,6 +208,13 @@ int NetRunner::TrainLoop() { if (acc > kThreshold) return 0; } } + auto end_time = std::chrono::high_resolution_clock::now(); + auto time_cost = std::chrono::duration(end_time - start_time); + if (cycles_ > 0) { + std::cout << "AvgRunTime: " << time_cost.count() / cycles_ << " ms" << std::endl; + } + struct mallinfo info = mallinfo(); + std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl; return 0; } diff --git a/mindspore/lite/src/lite_model.cc b/mindspore/lite/src/lite_model.cc index bc63a058f1b..dcf4cd05a07 100644 --- a/mindspore/lite/src/lite_model.cc +++ b/mindspore/lite/src/lite_model.cc @@ -24,6 +24,7 @@ #include #include "src/common/prim_util.h" #include "src/common/graph_util.h" +#include "src/common/file_utils.h" #ifdef ENABLE_V0 #include "src/ops/compat/compat_register.h" #endif @@ -416,56 +417,15 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { return model; } -std::unique_ptr ReadFileToBuf(const std::string &filename, size_t *size) { - std::ifstream ifs(filename, std::ifstream::in | std::ifstream::binary); - if (!ifs.good()) { - MS_LOG(ERROR) << "File: " << filename << " does not exist"; - return std::unique_ptr(nullptr); - } - - if (!ifs.is_open()) { - MS_LOG(ERROR) << "File: " << filename << " open failed"; - return std::unique_ptr(nullptr); - } - - ifs.seekg(0, std::ios::end); - auto tellg_ret = ifs.tellg(); - if (tellg_ret <= 0) { - MS_LOG(ERROR) << "Could not read file " << filename; - return std::unique_ptr(nullptr); - } - size_t fsize = static_cast(tellg_ret); - - std::unique_ptr buf(new (std::nothrow) char[fsize]); - if (buf == nullptr) { - MS_LOG(ERROR) << "malloc buf failed, file: " << filename; - ifs.close(); - return std::unique_ptr(nullptr); - } - - ifs.seekg(0, std::ios::beg); - ifs.read(buf.get(), fsize); - if (!ifs) { - MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename; - ifs.close(); - return std::unique_ptr(nullptr); - } - ifs.close(); - if (size != nullptr) { - *size = fsize; - } - return buf; -} - Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } Model *Model::Import(const char *filename) { size_t size = -1; - auto buf = ReadFileToBuf(filename, &size); + auto buf = ReadFile(filename, &size); if (buf == nullptr) { return nullptr; } - return ImportFromBuffer(buf.get(), size, false); + return ImportFromBuffer(buf, size, false); } int Model::Export(Model *model, char *buffer, size_t *len) { diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 39f316e4af8..faa706fc160 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -678,6 +678,10 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const { int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type, FormatType format) { + if (file_name.empty()) { + MS_LOG(ERROR) << "File name cannot be empty"; + return RET_ERROR; + } if (format != FT_FLATBUFFERS) { MS_LOG(ERROR) << "Currently only flatbuffer format is supported"; return RET_ERROR; diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index e7127262d4a..929ad786e6d 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -42,7 +42,6 @@ namespace mindspore { namespace lite { -std::unique_ptr ReadFileToBuf(const std::string &filename, size_t *size); using CreatorOp = std::tuple; class TrainSession : virtual public lite::LiteSession { public: @@ -58,6 +57,8 @@ class TrainSession : virtual public lite::LiteSession { int Train() override; int Eval() override; + bool IsTrain() override { return train_mode_; } + bool IsEval() override { return !train_mode_; } int SetLearningRate(float learning_rate) override; float GetLearningRate() override; int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override; diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc index d40b752b38a..e11f60ab779 100644 --- a/mindspore/lite/src/train/transfer_session.cc +++ b/mindspore/lite/src/train/transfer_session.cc @@ -24,6 +24,7 @@ #include #include "include/errorcode.h" #include "src/common/utils.h" +#include "src/common/file_utils.h" #include "src/tensor.h" #include "src/train/loss_kernel.h" #include "src/train/optimizer_kernel.h" @@ -300,15 +301,15 @@ session::LiteSession *session::LiteSession::CreateTransferSession(const std::str const lite::TrainCfg *cfg) { size_t size_head = 0; size_t size_backbone = 0; - auto buf_head = lite::ReadFileToBuf(filename_head, &size_head); + auto buf_head = lite::ReadFile(filename_head.c_str(), &size_head); if (buf_head == nullptr) { return nullptr; } - auto buf_backbone = lite::ReadFileToBuf(filename_backbone, &size_backbone); + auto buf_backbone = lite::ReadFile(filename_backbone.c_str(), &size_backbone); if (buf_backbone == nullptr) { return nullptr; } - return CreateTransferSessionInt(buf_backbone.get(), size_backbone, buf_head.get(), size_head, ctxt, train_mode, cfg); + return CreateTransferSessionInt(buf_backbone, size_backbone, buf_head, size_head, ctxt, train_mode, cfg); } } // namespace mindspore diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc index 86be81b92c2..2dd9cfb7ff5 100644 --- a/mindspore/lite/tools/anf_exporter/fetch_content.cc +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -283,9 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F MS_LOG(ERROR) << "fetch information from default param failed."; return RET_ERROR; } - - // attr weightFormat is only used by conv-like ops' second input auto prim = GetValueNode(cnode->input(0)); + if (prim->GetAttr(ops::kFormat) != nullptr) { + auto value = prim->GetAttr(ops::kFormat); + if (value->isa()) { + data_info->format_ = GetValue(value); + } + } + // attr weightFormat is only used by conv-like ops' second input if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) && diff --git a/mindspore/lite/tools/converter/parser/parser_utils.cc b/mindspore/lite/tools/converter/parser/parser_utils.cc index f2c29bdb679..3ea454cc2ee 100644 --- a/mindspore/lite/tools/converter/parser/parser_utils.cc +++ b/mindspore/lite/tools/converter/parser/parser_utils.cc @@ -123,8 +123,8 @@ int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format return lite::RET_OK; } -int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node, - std::vector perm) { +int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t dst_format, int64_t format, + const ParameterPtr &weight_node, std::vector perm) { MS_ASSERT(graph != nullptr); MS_ASSERT(weight_node != nullptr); auto node_list = TopoSort(graph->get_return()); @@ -158,6 +158,7 @@ int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, c auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_sharing_perm"); auto prim = std::make_shared(); prim->AddAttr("quant_params", std::make_shared(1, 1)); + prim->AddAttr(ops::kFormat, MakeValue(dst_format)); auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node}); if (!weight_node->has_default()) { MS_LOG(DEBUG) << "Weight parameter should has default parameter."; @@ -198,7 +199,7 @@ int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const Paramet MS_LOG(ERROR) << "get perm failed."; return status; } - status = TransposeInsertForWeightSharing(graph, format, weight_node, perm); + status = TransposeInsertForWeightSharing(graph, dst_format, format, weight_node, perm); if (status != lite::RET_OK) { MS_LOG(ERROR) << "transpose insert failed."; } diff --git a/mindspore/lite/tools/converter/parser/parser_utils.h b/mindspore/lite/tools/converter/parser/parser_utils.h index a1866e317b4..d34379367d2 100644 --- a/mindspore/lite/tools/converter/parser/parser_utils.h +++ b/mindspore/lite/tools/converter/parser/parser_utils.h @@ -34,8 +34,8 @@ int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &con std::vector perm); int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node, schema::Format src_format, schema::Format dst_format); -int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node, - std::vector perm); +int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t dst_format, int64_t format, + const ParameterPtr &weight_node, std::vector perm); int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node, schema::Format src_format, schema::Format dst_format); } // namespace lite