forked from mindspore-Ecosystem/mindspore
!48483 [Feature] Build Model适用大模型场景
Merge pull request !48483 from douzhixing/auto-tune
This commit is contained in:
commit
7c1c88c0ba
|
@ -178,15 +178,16 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT
|
|||
}
|
||||
|
||||
bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
|
||||
const size_t key_len, const std::string &enc_mode) {
|
||||
const size_t key_len, const std::string &enc_mode,
|
||||
size_t *size) {
|
||||
// serialize model
|
||||
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
||||
auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
size_t size = builder.GetSize();
|
||||
*size = builder.GetSize();
|
||||
auto content = builder.GetBufferPointer();
|
||||
if (!SerializeModel(content, size, key, key_len, enc_mode)) {
|
||||
if (!SerializeModel(content, *size, key, key_len, enc_mode)) {
|
||||
MS_LOG(ERROR) << "Serialize graph failed";
|
||||
return false;
|
||||
}
|
||||
|
@ -221,21 +222,33 @@ uint8_t *MetaGraphSerializer::GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuil
|
|||
|
||||
int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
|
||||
const size_t key_len, const std::string &enc_mode) {
|
||||
MetaGraphSerializer meta_graph_serializer;
|
||||
size_t size = 0;
|
||||
auto ret = MetaGraphSerializer::Save(graph, output_path, &size, key, key_len, enc_mode);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size,
|
||||
const Byte *key, const size_t key_len, const std::string &enc_mode) {
|
||||
MetaGraphSerializer meta_graph_serializer;
|
||||
*size = 0;
|
||||
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
||||
auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, &size);
|
||||
auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, size);
|
||||
if (!meta_graph_serializer.InitPath(output_path)) {
|
||||
MS_LOG(ERROR) << "Init path failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto save_together = (size < kModelSizeLimit);
|
||||
size_t tensors_size = 0;
|
||||
for (auto &tensor : graph.allTensors) {
|
||||
tensors_size += tensor->data.size();
|
||||
}
|
||||
|
||||
auto save_together = (tensors_size < kModelSizeLimit && *size < kModelSizeLimit);
|
||||
if (!meta_graph_serializer.Init(graph, save_together)) {
|
||||
MS_LOG(ERROR) << "Init MetaGraphSerializer failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (save_together) {
|
||||
if (!meta_graph_serializer.SerializeModel(buffer, size, key, key_len, enc_mode)) {
|
||||
if (!meta_graph_serializer.SerializeModel(buffer, *size, key, key_len, enc_mode)) {
|
||||
MS_LOG(ERROR) << "Serialize graph failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -244,10 +257,12 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
|
|||
MS_LOG(ERROR) << "Serialize graph weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) {
|
||||
size_t model_size = 0;
|
||||
if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode, &model_size)) {
|
||||
MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
*size = model_size + tensors_size;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -33,6 +33,8 @@ class MetaGraphSerializer {
|
|||
// save serialized fb model
|
||||
static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {},
|
||||
const size_t key_len = 0, const std::string &enc_mode = "");
|
||||
static int Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size, const Byte *key = {},
|
||||
const size_t key_len = 0, const std::string &enc_mode = "");
|
||||
|
||||
private:
|
||||
MetaGraphSerializer() = default;
|
||||
|
@ -48,7 +50,7 @@ class MetaGraphSerializer {
|
|||
bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);
|
||||
|
||||
bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len,
|
||||
const std::string &enc_mode);
|
||||
const std::string &enc_mode, size_t *size = 0);
|
||||
|
||||
bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
|
||||
const std::string &enc_mode);
|
||||
|
|
|
@ -345,10 +345,9 @@ int ConfigFileParser::ParseMixedBitQuantString(const std::map<std::string, std::
|
|||
if (maps.find(kMixedBitWeightQuantParam) != maps.end()) {
|
||||
const auto &map = maps.at(kMixedBitWeightQuantParam);
|
||||
std::map<std::string, std::string &> parse_map{
|
||||
{"init_scale", mixed_bit_quant_string_.init_scale},
|
||||
{"auto_tune", mixed_bit_quant_string_.auto_tune},
|
||||
{"use_cv_data", mixed_bit_quant_string_.use_cv_data},
|
||||
{"max_iterations", mixed_bit_quant_string_.max_iterations},
|
||||
{"init_scale", mixed_bit_quant_string_.init_scale}, {"auto_tune", mixed_bit_quant_string_.auto_tune},
|
||||
{"use_cv_data", mixed_bit_quant_string_.use_cv_data}, {"max_iterations", mixed_bit_quant_string_.max_iterations},
|
||||
{"workspace", mixed_bit_quant_string_.workspace},
|
||||
};
|
||||
return SetMapData(map, parse_map, kMixedBitWeightQuantParam);
|
||||
}
|
||||
|
|
|
@ -53,6 +53,7 @@ struct MixedBitWeightQuantString {
|
|||
std::string auto_tune;
|
||||
std::string use_cv_data;
|
||||
std::string max_iterations;
|
||||
std::string workspace;
|
||||
};
|
||||
|
||||
struct WeightQuantString {
|
||||
|
|
|
@ -170,6 +170,9 @@ int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &
|
|||
MS_LOG(ERROR) << "INPUT ILLEGAL: auto_tune should be true or false.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
// this is required only for model larger than 2G
|
||||
mixed_bit_weight_quant->workspace = mixed_bit_weight_quant_string.workspace;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -129,6 +129,7 @@ struct MixedBitWeightQuantParam {
|
|||
bool auto_tune = false;
|
||||
bool use_cv_data = false;
|
||||
int max_iterations = kMinIterations;
|
||||
std::string workspace; // only for model larger than 2G
|
||||
};
|
||||
|
||||
struct FullQuantParam {
|
||||
|
|
|
@ -47,6 +47,8 @@ namespace mindspore::lite::quant {
|
|||
namespace {
|
||||
constexpr size_t kGatherAxisIndex = 3;
|
||||
constexpr int kDefaultThreadNum = 4;
|
||||
constexpr size_t kEncMaxLen = 16;
|
||||
constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
|
||||
} // namespace
|
||||
|
||||
int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
|
||||
|
@ -204,6 +206,27 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
|
|||
return primitive_c->name();
|
||||
}
|
||||
|
||||
/// Builds `model` from a temporary file instead of an in-memory buffer.
///
/// The caller uses this path when the summed tensor data of the graph reaches kModelSizeLimit.
/// The graph is serialized to `<workspace>/tmp.ms` through MetaGraphSerializer::Save, and the model
/// is then built from that file.
///
/// @param meta_graph Graph to serialize.
/// @param param      Converter parameters. `mixedBitWeightQuantParam.workspace` must be set to the
///                   directory that receives the temporary model; `encrypt_mode` is forwarded to
///                   the serializer.
/// @param model      Model object that is built from the temporary file.
/// @param context    Context forwarded to `Model::Build`.
/// @param size       Out-parameter forwarded to MetaGraphSerializer::Save, which writes the
///                   serialized size into it.
/// @return kLiteNullptr when `param`, `model` or `size` is null; kLiteError when the workspace is
///         empty or saving fails; otherwise the status returned by `Model::Build`.
Status LargeModelBuildModel(const schema::MetaGraphT &meta_graph, const std::shared_ptr<ConverterPara> &param,
                            const std::shared_ptr<mindspore::Model> &model, const std::shared_ptr<Context> &context,
                            size_t *size) {
  // All three are dereferenced below (or inside Save), so reject null up front.
  if (param == nullptr || model == nullptr || size == nullptr) {
    MS_LOG(ERROR) << "param, model or size is nullptr.";
    return kLiteNullptr;
  }
  const std::string &workspace = param->mixedBitWeightQuantParam.workspace;
  if (workspace.empty()) {
    MS_LOG(ERROR) << "The model is larger than 2G, mixedBitWeightQuant config needs to set workspace to save tmp model";
    return kLiteError;
  }
  // NOTE(review): the temporary file is not removed here - confirm whether a caller cleans it up.
  const std::string tmp_save_file_path = workspace + "/tmp.ms";
  // The key buffer is zero-filled and its length is passed as 0.
  unsigned char encKey[kEncMaxLen] = {0};
  size_t keyLen = 0;
  auto status = MetaGraphSerializer::Save(meta_graph, tmp_save_file_path, size, encKey, keyLen, param->encrypt_mode);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Save Large Model Failed: " << status << " " << GetErrorInfo(status);
    return kLiteError;
  }

  mindspore::ModelType model_type = kMindIR_Lite;
  return model->Build(tmp_save_file_path, model_type, context);
}
|
||||
|
||||
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
|
||||
const std::shared_ptr<ConverterPara> &param, size_t *size) {
|
||||
FuncGraphPtr func_graph_clone;
|
||||
|
@ -236,17 +259,6 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
|
|||
return kLiteError;
|
||||
}
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
*size = builder.GetSize();
|
||||
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
||||
if (content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer return null";
|
||||
delete meta_graph;
|
||||
return kLiteNullptr;
|
||||
}
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "New context failed while running.";
|
||||
|
@ -264,6 +276,30 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
|
|||
}
|
||||
auto &device_list = context->MutableDeviceInfo();
|
||||
device_list.push_back(device_info);
|
||||
|
||||
size_t tensors_size = 0;
|
||||
for (auto &tensor : meta_graph->allTensors) {
|
||||
tensors_size += tensor->data.size();
|
||||
}
|
||||
|
||||
if (tensors_size >= kModelSizeLimit) {
|
||||
auto ret = LargeModelBuildModel(*meta_graph, param, model, context, size);
|
||||
delete meta_graph;
|
||||
return ret;
|
||||
}
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
*size = builder.GetSize();
|
||||
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
||||
if (content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer return null";
|
||||
delete meta_graph;
|
||||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
auto ret = model->Build(content, *size, kMindIR, context);
|
||||
delete meta_graph;
|
||||
return ret;
|
||||
|
|
Loading…
Reference in New Issue