!48483 [Feature] Adapt Build Model to large-model scenarios

Merge pull request !48483 from douzhixing/auto-tune
i-robot 2023-02-10 06:40:51 +00:00 committed by Gitee
commit 7c1c88c0ba
7 changed files with 81 additions and 24 deletions
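
Context for the change: FlatBuffers addresses a serialized buffer with 32-bit offsets, so a single buffer tops out at 2 GiB. This commit lets the converter handle models at or past that limit by saving the graph and its weights through a temporary file in a user-configured workspace directory and building the model from that file instead of from an in-memory buffer; the serialized size is threaded out through new size_t* out-parameters. The workspace directory comes from a new `workspace` key in the mixed-bit weight-quant section of the converter config file. A minimal config sketch follows; the section name [mixed_bit_weight_quant_param] is an assumption inferred from the kMixedBitWeightQuantParam constant parsed below, and all values are placeholders:

[mixed_bit_weight_quant_param]
init_scale=0.02
auto_tune=false
use_cv_data=false
max_iterations=10000
workspace=/path/to/tmp_dir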

View File

@@ -178,15 +178,16 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT
 }
 
 bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
-                                                        const size_t key_len, const std::string &enc_mode) {
+                                                        const size_t key_len, const std::string &enc_mode,
+                                                        size_t *size) {
   // serialize model
   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
   auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
   builder.Finish(offset);
   schema::FinishMetaGraphBuffer(builder, offset);
-  size_t size = builder.GetSize();
+  *size = builder.GetSize();
   auto content = builder.GetBufferPointer();
-  if (!SerializeModel(content, size, key, key_len, enc_mode)) {
+  if (!SerializeModel(content, *size, key, key_len, enc_mode)) {
     MS_LOG(ERROR) << "Serialize graph failed";
     return false;
   }
@@ -221,21 +222,33 @@ uint8_t *MetaGraphSerializer::GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuil
 int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
                               const size_t key_len, const std::string &enc_mode) {
-  MetaGraphSerializer meta_graph_serializer;
   size_t size = 0;
+  auto ret = MetaGraphSerializer::Save(graph, output_path, &size, key, key_len, enc_mode);
+  return ret;
+}
+
+int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size,
+                              const Byte *key, const size_t key_len, const std::string &enc_mode) {
+  MetaGraphSerializer meta_graph_serializer;
+  *size = 0;
   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
-  auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, &size);
+  auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, size);
   if (!meta_graph_serializer.InitPath(output_path)) {
     MS_LOG(ERROR) << "Init path failed";
     return RET_ERROR;
   }
-  auto save_together = (size < kModelSizeLimit);
+  size_t tensors_size = 0;
+  for (auto &tensor : graph.allTensors) {
+    tensors_size += tensor->data.size();
+  }
+  auto save_together = (tensors_size < kModelSizeLimit && *size < kModelSizeLimit);
   if (!meta_graph_serializer.Init(graph, save_together)) {
     MS_LOG(ERROR) << "Init MetaGraphSerializer failed";
     return RET_ERROR;
   }
   if (save_together) {
-    if (!meta_graph_serializer.SerializeModel(buffer, size, key, key_len, enc_mode)) {
+    if (!meta_graph_serializer.SerializeModel(buffer, *size, key, key_len, enc_mode)) {
       MS_LOG(ERROR) << "Serialize graph failed";
       return RET_ERROR;
     }
@@ -244,10 +257,12 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
       MS_LOG(ERROR) << "Serialize graph weight failed";
       return RET_ERROR;
     }
-    if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) {
+    size_t model_size = 0;
+    if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode, &model_size)) {
       MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
       return RET_ERROR;
     }
+    *size = model_size + tensors_size;
   }
   return RET_OK;
 }
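
The refactor above follows a delegate-to-overload pattern: the original Save keeps its signature for existing callers and forwards to a new overload that also reports the serialized size through a size_t* out-parameter (the flatbuffer size when everything is saved together, model size plus raw weight bytes when weights are externalized). A self-contained sketch of the pattern, with illustrative names rather than the repo's:

#include <cstddef>
#include <string>

// Illustrative sketch only (not the repo's API): the legacy entry point
// delegates to a new overload that also reports how many bytes were written.
int SaveImpl(const std::string &path, size_t *size) {
  *size = 0;
  // ... serialize to `path`, accumulating the byte count into *size ...
  return 0;  // 0 standing in for RET_OK
}

// Legacy signature kept unchanged for callers that do not need the size.
int Save(const std::string &path) {
  size_t size = 0;
  return SaveImpl(path, &size);
}

int main() { return Save("model.ms"); }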

View File

@@ -33,6 +33,8 @@ class MetaGraphSerializer {
   // save serialized fb model
   static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {},
                   const size_t key_len = 0, const std::string &enc_mode = "");
+  static int Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size, const Byte *key = {},
+                  const size_t key_len = 0, const std::string &enc_mode = "");
 
  private:
   MetaGraphSerializer() = default;

@@ -48,7 +50,7 @@ class MetaGraphSerializer {
   bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);
 
   bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len,
-                                     const std::string &enc_mode);
+                                     const std::string &enc_mode, size_t *size = 0);
 
   bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
                       const std::string &enc_mode);
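
One editorial note on the declaration above: in `size_t *size = 0`, the 0 is a null-pointer default, yet the implementation in the first hunk writes `*size = builder.GetSize();` unconditionally, so the default is only safe as long as every caller passes a real pointer (as this commit's callers do). A self-contained sketch of the guarded variant, purely illustrative and not a change made by this commit:

#include <cstddef>

// Illustrative: tolerate the null default by redirecting to a local.
bool Serialize(size_t *size = nullptr) {
  size_t local = 0;
  if (size == nullptr) size = &local;  // guard before writing through it
  *size = 128;                         // placeholder for the real byte count
  return true;
}

int main() { return Serialize() ? 0 : 1; }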

View File

@@ -345,10 +345,9 @@ int ConfigFileParser::ParseMixedBitQuantString(const std::map<std::string, std::
   if (maps.find(kMixedBitWeightQuantParam) != maps.end()) {
     const auto &map = maps.at(kMixedBitWeightQuantParam);
     std::map<std::string, std::string &> parse_map{
-      {"init_scale", mixed_bit_quant_string_.init_scale},
-      {"auto_tune", mixed_bit_quant_string_.auto_tune},
-      {"use_cv_data", mixed_bit_quant_string_.use_cv_data},
-      {"max_iterations", mixed_bit_quant_string_.max_iterations},
+      {"init_scale", mixed_bit_quant_string_.init_scale}, {"auto_tune", mixed_bit_quant_string_.auto_tune},
+      {"use_cv_data", mixed_bit_quant_string_.use_cv_data}, {"max_iterations", mixed_bit_quant_string_.max_iterations},
+      {"workspace", mixed_bit_quant_string_.workspace},
     };
     return SetMapData(map, parse_map, kMixedBitWeightQuantParam);
   }
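
Besides reflowing the initializer, the hunk above adds the new "workspace" entry. Note the map's mapped type is std::string &: each entry aliases a field of mixed_bit_quant_string_, so SetMapData can assign parsed file values straight into the struct. A self-contained sketch of that reference-valued-map technique (names are illustrative):

#include <iostream>
#include <map>
#include <string>

struct QuantString {
  std::string init_scale;
  std::string workspace;
};

int main() {
  QuantString dst;
  // The mapped type is a reference: map values alias fields of dst.
  std::map<std::string, std::string &> parse_map{
      {"init_scale", dst.init_scale},
      {"workspace", dst.workspace},
  };
  const std::map<std::string, std::string> file_values{{"workspace", "/tmp/ws"}};
  for (auto &kv : parse_map) {
    auto it = file_values.find(kv.first);
    if (it != file_values.end()) {
      kv.second = it->second;  // writes into dst through the reference
    }
  }
  std::cout << dst.workspace << std::endl;  // prints /tmp/ws
}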

View File

@@ -53,6 +53,7 @@ struct MixedBitWeightQuantString {
   std::string auto_tune;
   std::string use_cv_data;
   std::string max_iterations;
+  std::string workspace;
 };
 
 struct WeightQuantString {

View File

@@ -170,6 +170,9 @@ int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &
     MS_LOG(ERROR) << "INPUT ILLEGAL: auto_tune should be true or false.";
     return RET_INPUT_PARAM_INVALID;
   }
+  // this is required only for model larger than 2G
+  mixed_bit_weight_quant->workspace = mixed_bit_weight_quant_string.workspace;
   return RET_OK;
 }

View File

@@ -129,6 +129,7 @@ struct MixedBitWeightQuantParam {
   bool auto_tune = false;
   bool use_cv_data = false;
   int max_iterations = kMinIterations;
+  std::string workspace;  // only for model larger than 2G
 };
 
 struct FullQuantParam {

View File

@@ -47,6 +47,8 @@ namespace mindspore::lite::quant {
 namespace {
 constexpr size_t kGatherAxisIndex = 3;
 constexpr int kDefaultThreadNum = 4;
+constexpr size_t kEncMaxLen = 16;
+constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
 }  // namespace
 
 int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
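
Two new file-scope constants: kEncMaxLen bounds the encryption-key buffer, and kModelSizeLimit is the 2 GiB FlatBuffers ceiling. The static_cast<size_t>(2) is what keeps the multiplication out of int arithmetic, since 2 * 1024 * 1024 * 1024 = 2147483648 exceeds INT32_MAX = 2147483647. A self-contained check, assuming a 64-bit size_t:

#include <cstddef>
#include <cstdint>

// Promoting the first operand to size_t makes the whole product unsigned
// 64-bit; done in plain int, 2 * 1024 * 1024 * 1024 would overflow.
constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
static_assert(kModelSizeLimit == 2147483648ULL, "exactly 2 GiB in bytes");

int main() { return 0; }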
@@ -204,6 +206,27 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
   return primitive_c->name();
 }
 
+Status LargeModelBuildModel(const schema::MetaGraphT &meta_graph, const std::shared_ptr<ConverterPara> &param,
+                            const std::shared_ptr<mindspore::Model> &model, const std::shared_ptr<Context> &context,
+                            size_t *size) {
+  if (param->mixedBitWeightQuantParam.workspace.empty()) {
+    MS_LOG(ERROR) << "The model is larger than 2G, mixedBitWeightQuant config needs to set workspace to save tmp model";
+    return kLiteError;
+  }
+  std::string tmp_save_file_path = param->mixedBitWeightQuantParam.workspace + "/tmp.ms";
+  unsigned char encKey[kEncMaxLen] = {0};
+  size_t keyLen = 0;
+  auto status = MetaGraphSerializer::Save(meta_graph, tmp_save_file_path, size, encKey, keyLen, param->encrypt_mode);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Save Large Model Failed: " << status << " " << GetErrorInfo(status);
+    return kLiteError;
+  }
+  mindspore::ModelType model_type = kMindIR_Lite;
+  auto ret = model->Build(tmp_save_file_path, model_type, context);
+  return ret;
+}
+
 Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
                              const std::shared_ptr<ConverterPara> &param, size_t *size) {
   FuncGraphPtr func_graph_clone;
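
LargeModelBuildModel never materializes the model as one FlatBuffer: it saves to <workspace>/tmp.ms through the new Save overload (the zeroed key with keyLen == 0 means the temporary file is written unencrypted) and then uses the file-path overload of mindspore::Model::Build with kMindIR_Lite. A minimal sketch of the path handling, with a hypothetical workspace value:

#include <iostream>
#include <string>

int main() {
  // Hypothetical config value; the code above rejects an empty workspace.
  std::string workspace = "/data/ws";
  std::string tmp_save_file_path = workspace + "/tmp.ms";
  std::cout << tmp_save_file_path << std::endl;  // prints /data/ws/tmp.ms
}

One design observation: because the temporary name is fixed, two conversions pointed at the same workspace would overwrite each other's tmp.ms, so each run should get its own workspace directory.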
@@ -236,17 +259,6 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
     return kLiteError;
   }
 
-  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
-  auto offset = schema::MetaGraph::Pack(builder, meta_graph);
-  builder.Finish(offset);
-  schema::FinishMetaGraphBuffer(builder, offset);
-  *size = builder.GetSize();
-  auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
-  if (content == nullptr) {
-    MS_LOG(ERROR) << "GetBufferPointer return null";
-    delete meta_graph;
-    return kLiteNullptr;
-  }
   auto context = std::make_shared<mindspore::Context>();
   if (context == nullptr) {
     MS_LOG(ERROR) << "New context failed while running.";
@@ -264,6 +276,30 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
   }
   auto &device_list = context->MutableDeviceInfo();
   device_list.push_back(device_info);
+  size_t tensors_size = 0;
+  for (auto &tensor : meta_graph->allTensors) {
+    tensors_size += tensor->data.size();
+  }
+  if (tensors_size >= kModelSizeLimit) {
+    auto ret = LargeModelBuildModel(*meta_graph, param, model, context, size);
+    delete meta_graph;
+    return ret;
+  }
+  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
+  auto offset = schema::MetaGraph::Pack(builder, meta_graph);
+  builder.Finish(offset);
+  schema::FinishMetaGraphBuffer(builder, offset);
+  *size = builder.GetSize();
+  auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
+  if (content == nullptr) {
+    MS_LOG(ERROR) << "GetBufferPointer return null";
+    delete meta_graph;
+    return kLiteNullptr;
+  }
   auto ret = model->Build(content, *size, kMindIR, context);
   delete meta_graph;
   return ret;
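
Taken together, these last two hunks relocate the in-memory packing (the exact lines removed in the previous hunk) to after context setup, so that BuildModelByFuncGraph can first sum the raw weight bytes and route oversized models to the file-backed path; only models under the limit still pack into a FlatBuffer and build from memory. A self-contained sketch of that dispatch predicate (types are stand-ins for the schema's):

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;

// Stand-in for schema::TensorT, which stores raw weight bytes in `data`.
struct TensorT {
  std::vector<uint8_t> data;
};

// Mirrors the check above: sum the raw tensor bytes and route models at or
// past 2 GiB to the file-backed build instead of the in-memory FlatBuffer.
bool NeedsFileBackedBuild(const std::vector<std::unique_ptr<TensorT>> &all_tensors) {
  size_t tensors_size = 0;
  for (auto &tensor : all_tensors) {
    tensors_size += tensor->data.size();
  }
  return tensors_size >= kModelSizeLimit;
}

int main() { return 0; }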