forked from mindspore-Ecosystem/mindspore
!18818 [MS][LITE] fix lite converter issue and security problem and add time print
Merge pull request !18818 from zhengjun10/master
This commit is contained in:
commit
43174475e6
|
@ -16,9 +16,11 @@
|
|||
|
||||
#include "src/net_runner.h"
|
||||
#include <getopt.h>
|
||||
#include <malloc.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include "include/context.h"
|
||||
|
@ -183,6 +185,7 @@ int NetRunner::TrainLoop() {
|
|||
session_->Train();
|
||||
float min_loss = 1000.;
|
||||
float max_acc = 0.;
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
for (int i = 0; i < cycles_; i++) {
|
||||
FillInputData(ds_.train_data());
|
||||
session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
|
||||
|
@ -205,6 +208,13 @@ int NetRunner::TrainLoop() {
|
|||
if (acc > kThreshold) return 0;
|
||||
}
|
||||
}
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
auto time_cost = std::chrono::duration<double, std::milli>(end_time - start_time);
|
||||
if (cycles_ > 0) {
|
||||
std::cout << "AvgRunTime: " << time_cost.count() / cycles_ << " ms" << std::endl;
|
||||
}
|
||||
struct mallinfo info = mallinfo();
|
||||
std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <memory>
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/common/graph_util.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#ifdef ENABLE_V0
|
||||
#include "src/ops/compat/compat_register.h"
|
||||
#endif
|
||||
|
@ -416,56 +417,15 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
|||
return model;
|
||||
}
|
||||
|
||||
std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size) {
|
||||
std::ifstream ifs(filename, std::ifstream::in | std::ifstream::binary);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "File: " << filename << " does not exist";
|
||||
return std::unique_ptr<char[]>(nullptr);
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "File: " << filename << " open failed";
|
||||
return std::unique_ptr<char[]>(nullptr);
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
auto tellg_ret = ifs.tellg();
|
||||
if (tellg_ret <= 0) {
|
||||
MS_LOG(ERROR) << "Could not read file " << filename;
|
||||
return std::unique_ptr<char[]>(nullptr);
|
||||
}
|
||||
size_t fsize = static_cast<size_t>(tellg_ret);
|
||||
|
||||
std::unique_ptr<char[]> buf(new (std::nothrow) char[fsize]);
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc buf failed, file: " << filename;
|
||||
ifs.close();
|
||||
return std::unique_ptr<char[]>(nullptr);
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf.get(), fsize);
|
||||
if (!ifs) {
|
||||
MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename;
|
||||
ifs.close();
|
||||
return std::unique_ptr<char[]>(nullptr);
|
||||
}
|
||||
ifs.close();
|
||||
if (size != nullptr) {
|
||||
*size = fsize;
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
|
||||
|
||||
Model *Model::Import(const char *filename) {
|
||||
size_t size = -1;
|
||||
auto buf = ReadFileToBuf(filename, &size);
|
||||
auto buf = ReadFile(filename, &size);
|
||||
if (buf == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return ImportFromBuffer(buf.get(), size, false);
|
||||
return ImportFromBuffer(buf, size, false);
|
||||
}
|
||||
|
||||
int Model::Export(Model *model, char *buffer, size_t *len) {
|
||||
|
|
|
@ -678,6 +678,10 @@ bool TrainSession::IsBN(kernel::LiteKernel *kernel) const {
|
|||
|
||||
int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type,
|
||||
FormatType format) {
|
||||
if (file_name.empty()) {
|
||||
MS_LOG(ERROR) << "File name cannot be empty";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (format != FT_FLATBUFFERS) {
|
||||
MS_LOG(ERROR) << "Currently only flatbuffer format is supported";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -42,7 +42,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::unique_ptr<char[]> ReadFileToBuf(const std::string &filename, size_t *size);
|
||||
using CreatorOp = std::tuple<mindspore::kernel::KernelKey, mindspore::kernel::KernelCreator>;
|
||||
class TrainSession : virtual public lite::LiteSession {
|
||||
public:
|
||||
|
@ -58,6 +57,8 @@ class TrainSession : virtual public lite::LiteSession {
|
|||
|
||||
int Train() override;
|
||||
int Eval() override;
|
||||
bool IsTrain() override { return train_mode_; }
|
||||
bool IsEval() override { return !train_mode_; }
|
||||
int SetLearningRate(float learning_rate) override;
|
||||
float GetLearningRate() override;
|
||||
int SetupVirtualBatch(int virtual_batch_multiplier, float lr = -1.0f, float momentum = -1.0f) override;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <memory>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/tensor.h"
|
||||
#include "src/train/loss_kernel.h"
|
||||
#include "src/train/optimizer_kernel.h"
|
||||
|
@ -300,15 +301,15 @@ session::LiteSession *session::LiteSession::CreateTransferSession(const std::str
|
|||
const lite::TrainCfg *cfg) {
|
||||
size_t size_head = 0;
|
||||
size_t size_backbone = 0;
|
||||
auto buf_head = lite::ReadFileToBuf(filename_head, &size_head);
|
||||
auto buf_head = lite::ReadFile(filename_head.c_str(), &size_head);
|
||||
if (buf_head == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto buf_backbone = lite::ReadFileToBuf(filename_backbone, &size_backbone);
|
||||
auto buf_backbone = lite::ReadFile(filename_backbone.c_str(), &size_backbone);
|
||||
if (buf_backbone == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return CreateTransferSessionInt(buf_backbone.get(), size_backbone, buf_head.get(), size_head, ctxt, train_mode, cfg);
|
||||
return CreateTransferSessionInt(buf_backbone, size_backbone, buf_head, size_head, ctxt, train_mode, cfg);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -283,9 +283,14 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
|||
MS_LOG(ERROR) << "fetch information from default param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// attr weightFormat is only used by conv-like ops' second input
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
auto value = prim->GetAttr(ops::kFormat);
|
||||
if (value->isa<mindspore::Int64Imm>()) {
|
||||
data_info->format_ = GetValue<int64_t>(value);
|
||||
}
|
||||
}
|
||||
// attr weightFormat is only used by conv-like ops' second input
|
||||
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) &&
|
||||
|
|
|
@ -123,8 +123,8 @@ int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
|
||||
std::vector<int> perm) {
|
||||
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t dst_format, int64_t format,
|
||||
const ParameterPtr &weight_node, std::vector<int> perm) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
|
@ -158,6 +158,7 @@ int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, c
|
|||
auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_sharing_perm");
|
||||
auto prim = std::make_shared<ops::Transpose>();
|
||||
prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
|
||||
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(dst_format));
|
||||
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
|
||||
if (!weight_node->has_default()) {
|
||||
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
|
||||
|
@ -198,7 +199,7 @@ int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const Paramet
|
|||
MS_LOG(ERROR) << "get perm failed.";
|
||||
return status;
|
||||
}
|
||||
status = TransposeInsertForWeightSharing(graph, format, weight_node, perm);
|
||||
status = TransposeInsertForWeightSharing(graph, dst_format, format, weight_node, perm);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "transpose insert failed.";
|
||||
}
|
||||
|
|
|
@ -34,8 +34,8 @@ int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &con
|
|||
std::vector<int> perm);
|
||||
int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
|
||||
schema::Format src_format, schema::Format dst_format);
|
||||
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
|
||||
std::vector<int> perm);
|
||||
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t dst_format, int64_t format,
|
||||
const ParameterPtr &weight_node, std::vector<int> perm);
|
||||
int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
|
||||
schema::Format src_format, schema::Format dst_format);
|
||||
} // namespace lite
|
||||
|
|
Loading…
Reference in New Issue