!49339 modify save_type api doc and trainModel para

Merge pull request !49339 from 周超/master5
i-robot 2023-03-02 11:57:04 +00:00 committed by Gitee
commit ca63485b81
5 changed files with 12 additions and 8 deletions

@@ -7,9 +7,9 @@ mindspore_lite.ModelType
 Applicable to the following scenarios:
-1. When converting, set the `save_type` parameter; `ModelType` is used to define the model type generated by the conversion.
+1. When calling `mindspore_lite.Converter`, set the `save_type` parameter; `ModelType` is used to define the model type generated by the conversion.
-2. After conversion, when loading or building a model from a file for inference, `ModelType` is used to define the input model framework type.
+2. After calling `mindspore_lite.Converter`, when loading or building a model from a file for inference, `ModelType` is used to define the input model framework type.
 Currently, the following `ModelType` values are supported:
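
As a concrete illustration of scenario 1, a conversion that sets `save_type` could look like the sketch below. This is a hedged sketch only, assuming the mindspore_lite Python API of the 2.0 line; the `fmk_type`, `model_file`, `output_file`, and `save_type` parameter names are taken from that API, and the model file names are hypothetical.

    import mindspore_lite as mslite

    # Scenario 1: save_type defines the model type that conversion generates,
    # e.g. MINDIR (cloud/device-unified inference) or MINDIR_LITE (device inference).
    converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE,
                                 model_file="mobilenet.tflite",  # hypothetical input model
                                 output_file="mobilenet",
                                 save_type=mslite.ModelType.MINDIR)
    converter.converter()  # runs the conversion; later releases rename this method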

@@ -32,10 +32,11 @@ class ModelType(Enum):
 Used in the following scenarios:
-1. When Converter, set `save_type` parameter, `ModelType` used to define the model type generated by Converter.
+1. When using `mindspore_lite.Converter`, set the `save_type` parameter; `ModelType` is used to define the model
+   type generated by Converter.
-2. After Converter, When loading or building a model from file for predicting, the `ModelType` is used to define
-   Input model framework type.
+2. After using `mindspore_lite.Converter`, when loading or building a model from a file for predicting,
+   `ModelType` is used to define the input model framework type.
 Currently, the following `ModelType` are supported:
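
Scenario 2 then uses `ModelType` when loading the converted model for predicting. Again a minimal sketch under the same mindspore_lite 2.0-line API assumption; `Model.build_from_file` and `Context` are that API's names, the file name is hypothetical, and device/target configuration is omitted because it differs across releases.

    import mindspore_lite as mslite

    # Scenario 2: ModelType identifies the framework type of the input model file.
    context = mslite.Context()  # device/target setup omitted; varies by release
    model = mslite.Model()
    model.build_from_file("mobilenet.mindir", mslite.ModelType.MINDIR, context)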

@@ -142,12 +142,12 @@ function Convert() {
 echo "./converter_lite --fmk=${model_fmk} --modelFile=${model_file} --weightFile=${weight_file} --outputFile=${output_file}\
 --inputDataType=${in_dtype} --outputDataType=${out_dtype} --inputShape=${spec_shapes} --fp16=${fp16_weight}\
 --configFile=${config_file} --saveType=${save_type} --optimize=${optimize} \
---trainModel=${train_model} --inputDataFormat=${input_format}"
+--inputDataFormat=${input_format}"
 ./converter_lite --fmk=${model_fmk} --modelFile=${model_file} --weightFile=${weight_file} --outputFile=${output_file}\
 --inputDataType=${in_dtype} --outputDataType=${out_dtype} --inputShape="${spec_shapes}" --fp16=${fp16_weight}\
 --configFile=${config_file} --saveType=${save_type} --optimize=${optimize} \
---trainModel=${train_model} --inputDataFormat=${input_format} >> "$4"
+--inputDataFormat=${input_format} >> "$4"
 else
 echo "./converter_lite --fmk=${model_fmk} --modelFile=${model_file} --weightFile=${weight_file} --outputFile=${output_file}\
 --inputDataType=${in_dtype} --outputDataType=${out_dtype} --inputShape=${spec_shapes} --fp16=${fp16_weight}\

@@ -50,10 +50,13 @@ Flags::Flags() {
 AddFlag(&Flags::saveFP16Str, "fp16",
 "Serialize const tensor in Float16 data type, only effective for const tensor in Float32 data type. on | off",
 "off");
+// Cloud inference does not support the trainModel parameter.
+#if !defined(ENABLE_CLOUD_FUSION_INFERENCE) && !defined(ENABLE_CLOUD_INFERENCE)
 AddFlag(&Flags::trainModelIn, "trainModel",
 "whether the model is going to be trained on device. "
 "true | false",
 "false");
+#endif
 AddFlag(&Flags::dec_key, "decryptKey",
 "The key used to decrypt the file, expressed in hexadecimal characters. Only valid when fmkIn is 'MINDIR'",
 "");

@@ -60,7 +60,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
 std::string outputDataTypeStr;
 DataType outputDataType;
 std::string configFile;
-std::string trainModelIn;
+std::string trainModelIn = "false";
 bool trainModel = false;
 std::string inTensorShape;
 mutable std::map<std::string, std::vector<int64_t>> graph_input_shape_map;