!18818 [MS][LITE] fix lite converter issues and a security problem, and add time printing

Merge pull request !18818 from zhengjun10/master
This commit is contained in:
i-robot 2021-06-26 03:29:07 +00:00 committed by Gitee
commit 43174475e6
8 changed files with 36 additions and 54 deletions

View File

@ -16,9 +16,11 @@
#include "src/net_runner.h"
#include <getopt.h>
#include <malloc.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <chrono>
#include <fstream>
#include <iostream>
#include "include/context.h"
@ -183,6 +185,7 @@ int NetRunner::TrainLoop() {
session_->Train();
float min_loss = 1000.;
float max_acc = 0.;
auto start_time = std::chrono::high_resolution_clock::now();
for (int i = 0; i < cycles_; i++) {
FillInputData(ds_.train_data());
session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
@ -205,6 +208,13 @@ int NetRunner::TrainLoop() {
if (acc > kThreshold) return 0;
}
}
auto end_time = std::chrono::high_resolution_clock::now();
auto time_cost = std::chrono::duration<double, std::milli>(end_time - start_time);
if (cycles_ > 0) {
std::cout << "AvgRunTime: " << time_cost.count() / cycles_ << " ms" << std::endl;
}
struct mallinfo info = mallinfo();
std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl;
return 0;
}

View File

@ -24,6 +24,7 @@
#include <memory>
#include "src/common/prim_util.h"
#include "src/common/graph_util.h"
#include "src/common/file_utils.h"
#ifdef ENABLE_V0
#include "src/ops/compat/compat_register.h"
#endif
@ -416,56 +417,15 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
return model;
}
// Read the entire contents of `filename` into a newly allocated char buffer.
// On success the buffer is returned and, when `size` is non-null, *size is set
// to the number of bytes read. On any failure (missing file, empty file,
// allocation failure, short read) a null pointer is returned and the error is
// logged via MS_LOG.
std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) {
  std::ifstream ifs(filename, std::ifstream::in | std::ifstream::binary);
  if (!ifs.good()) {
    MS_LOG(ERROR) << "File: " << filename << " does not exist";
    return nullptr;
  }
  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "File: " << filename << " open failed";
    return nullptr;
  }
  // Determine the file length by seeking to the end; a non-positive position
  // means the file is empty or unreadable.
  ifs.seekg(0, std::ios::end);
  auto end_pos = ifs.tellg();
  if (end_pos <= 0) {
    MS_LOG(ERROR) << "Could not read file " << filename;
    return nullptr;
  }
  auto file_size = static_cast<size_t>(end_pos);
  std::unique_ptr<char[]> buffer(new (std::nothrow) char[file_size]);
  if (buffer == nullptr) {
    MS_LOG(ERROR) << "malloc buf failed, file: " << filename;
    ifs.close();
    return nullptr;
  }
  ifs.seekg(0, std::ios::beg);
  ifs.read(buffer.get(), file_size);
  if (!ifs) {
    // Partial read: report how many bytes were actually consumed and bail out.
    MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename;
    ifs.close();
    return nullptr;
  }
  ifs.close();
  if (size != nullptr) {
    *size = file_size;
  }
  return buffer;
}
// Construct a Model from an in-memory flatbuffer of `size` bytes.
// The buffer is not taken over (take_buf == false).
Model *Model::Import(const char *model_buf, size_t size) {
  return ImportFromBuffer(model_buf, size, false);
}
// Load a model file from disk and import it; returns nullptr when the file
// cannot be read.
// FIX: the rendered diff left both the pre-change lines (ReadFileToBuf /
// buf.get()) and the post-change lines (ReadFile / buf) in place, which
// redeclares `buf` and leaves an unreachable second return. Resolved to the
// post-change version.
// NOTE(review): ReadFile returns a raw buffer and ImportFromBuffer is called
// with take_buf == false — confirm who frees the buffer, this looks like it
// may leak.
Model *Model::Import(const char *filename) {
  size_t size = -1;
  auto buf = ReadFile(filename, &size);
  if (buf == nullptr) {
    return nullptr;
  }
  return ImportFromBuffer(buf, size, false);
}
int Model::Export(Model *model, char *buffer, size_t *len) {

View File

@ -678,6 +678,10 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
FormatType format) {
if (file_name.empty()) {
MS_LOG(ERROR) << "File name cannot be empty";
return RET_ERROR;
}
if (format != FT_FLATBUFFERS) {
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
return RET_ERROR;

View File

@ -42,7 +42,6 @@
namespace mindspore {
namespace lite {
std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size);
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
class TrainSession : virtual public lite::LiteSession {
public:
@ -58,6 +57,8 @@ class TrainSession : virtual public lite::LiteSession {
int Train() override;
int Eval() override;
bool IsTrain() override { return train_mode_; }
bool IsEval() override { return !train_mode_; }
int SetLearningRate(float learning_rate) override;
float GetLearningRate() override;
int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;

View File

@ -24,6 +24,7 @@
#include <memory>
#include "include/errorcode.h"
#include "src/common/utils.h"
#include "src/common/file_utils.h"
#include "src/tensor.h"
#include "src/train/loss_kernel.h"
#include "src/train/optimizer_kernel.h"
@ -300,15 +301,15 @@ session::LiteSession *session::LiteSession::CreateTransferSession(const std::str
const lite::TrainCfg *cfg) {
size_t size_head = 0;
size_t size_backbone = 0;
auto buf_head = lite::ReadFileToBuf(filename_head, &size_head);
auto buf_head = lite::ReadFile(filename_head.c_str(), &size_head);
if (buf_head == nullptr) {
return nullptr;
}
auto buf_backbone = lite::ReadFileToBuf(filename_backbone, &size_backbone);
auto buf_backbone = lite::ReadFile(filename_backbone.c_str(), &size_backbone);
if (buf_backbone == nullptr) {
return nullptr;
}
return CreateTransferSessionInt(buf_backbone.get(), size_backbone, buf_head.get(), size_head, ctxt, train_mode, cfg);
return CreateTransferSessionInt(buf_backbone, size_backbone, buf_head, size_head, ctxt, train_mode, cfg);
}
} // namespace mindspore

View File

@ -283,9 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
MS_LOG(ERROR) << "fetch information from default param failed.";
return RET_ERROR;
}
// attr weightFormat is only used by conv-like ops' second input
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kFormat) != nullptr) {
auto value = prim->GetAttr(ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
data_info->format_ = GetValue<int64_t>(value);
}
}
// attr weightFormat is only used by conv-like ops' second input
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) &&

View File

@ -123,8 +123,8 @@ int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format
return lite::RET_OK;
}
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
std::vector<int> perm) {
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t dst_format, int64_t format,
const ParameterPtr &weight_node, std::vector<int> perm) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
auto node_list = TopoSort(graph->get_return());
@ -158,6 +158,7 @@ int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, c
auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_sharing_perm");
auto prim = std::make_shared<ops::Transpose>();
prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(dst_format));
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
if (!weight_node->has_default()) {
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
@ -198,7 +199,7 @@ int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const Paramet
MS_LOG(ERROR) << "get perm failed.";
return status;
}
status = TransposeInsertForWeightSharing(graph, format, weight_node, perm);
status = TransposeInsertForWeightSharing(graph, dst_format, format, weight_node, perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "transpose insert failed.";
}

View File

@ -34,8 +34,8 @@ int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &con
std::vector<int> perm);
int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
schema::Format src_format, schema::Format dst_format);
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
std::vector<int> perm);
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t dst_format, int64_t format,
const ParameterPtr &weight_node, std::vector<int> perm);
int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
schema::Format src_format, schema::Format dst_format);
} // namespace lite