!24924 modify ckpt bug and clean code

Merge pull request !24924 from changzherui/mod_ckpt_and_clean_code
This commit is contained in:
i-robot 2021-10-19 01:57:11 +00:00 committed by Gitee
commit d8363eff1b
8 changed files with 30 additions and 45 deletions

View File

@ -1,13 +0,0 @@
if(ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz")
set(MD5 "1bdbcecdd68ea8392630467646776e02")
else()
set(REQ_URL "https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz")
set(MD5 "512f2779d6215d4a36f366b6b9acdf1e")
endif()
mindspore_add_pkg(ms_onnx
VER 1.6.0
HEAD_ONLY ./
URL ${REQ_URL}
MD5 ${MD5})

View File

@ -100,5 +100,4 @@ if(ENABLE_TESTCASES OR ENABLE_CPP_ST)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
endif()
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS_F})

View File

@ -81,7 +81,7 @@ std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<s
int count = 0;
for (size_t i = 0; i < str.length(); i++) {
if (str[i] == '[') {
rules.push("[");
rules.push(std::string("["));
} else if (str[i] == ']') {
// rules
std::vector<P> vec;
@ -107,7 +107,7 @@ std::shared_ptr<T> ParserAttr(const std::string &str, const std::unordered_map<s
} else {
count++;
if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') {
auto value_name = str.substr(i - count + 1, count);
auto value_name = str.substr(static_cast<int>(i) - count + 1, count);
if (kv.find(value_name) == kv.end()) {
MS_LOG(ERROR) << "Node's attributes and shape do not match.";
return nullptr;
@ -290,7 +290,8 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
std::string initial_data = parameter_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), initial_data.data(),
initial_data.size());
if (ret != 0) {
MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
return false;
@ -560,7 +561,8 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
auto ret =
memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(ERROR) << "Obtain CNode in TensorForm occur memcpy_s error.";
return false;

View File

@ -42,8 +42,8 @@ class MindIREngine {
void Init(const AbstractBasePtrList &args);
static AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
void EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
void EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node);
void EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args);
void EvalReturnPrimitive(const CNodePtr &node);
void InferParameter(const AnfNodePtr &node);
void InferValueNode(const AnfNodePtr &node);
void InferCNode(const AnfNodePtr &node);
@ -195,7 +195,7 @@ void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr
SaveNodeInferResult(node, result);
}
void MindIREngine::EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr &node) {
void MindIREngine::EvalReturnPrimitive(const CNodePtr &node) {
if (node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
}
@ -220,8 +220,7 @@ void MindIREngine::EvalReturnPrimitive(const PrimitivePtr &prim, const CNodePtr
}
}
void MindIREngine::EvalPartialPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
const AbstractBasePtrListPtr &args) {
void MindIREngine::EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args) {
// Args has been resolved
if (args != nullptr) {
if (args->size() < 2) {
@ -307,12 +306,12 @@ void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosu
auto prim = func->prim();
// Return Primitive
if (prim->name() == prim::kPrimReturn->name()) {
EvalReturnPrimitive(prim, node);
EvalReturnPrimitive(node);
return;
}
// Partial Primitive
if (prim->name() == prim::kPrimPartial->name()) {
EvalPartialPrimitive(prim, node, args);
EvalPartialPrimitive(node, args);
return;
}
// common Primitive

View File

@ -121,7 +121,7 @@ bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
return true;
}
int endsWith(string s, string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; }
int endsWith(const string s, const string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; }
bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path, const unsigned char *dec_key,
const size_t key_len, const std::string &dec_mode) {
@ -175,13 +175,12 @@ std::string LoadPreprocess(const std::string &file_name) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
}
const char *file_path = file_name.c_str();
char abs_path_buff[PATH_MAX];
#ifdef _WIN32
_fullpath(abs_path_buff, file_path, PATH_MAX);
_fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
#else
if (!realpath(file_path, abs_path_buff)) {
if (!realpath(file_name.c_str(), abs_path_buff)) {
MS_LOG(ERROR) << "Load MindIR get absolute path failed";
}
#endif
@ -215,14 +214,13 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
}
const char *file_path = file_name.c_str();
char abs_path_buff[PATH_MAX];
vector<string> files;
#ifdef _WIN32
_fullpath(abs_path_buff, file_path, PATH_MAX);
_fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
#else
if (!realpath(file_path, abs_path_buff)) {
if (!realpath(file_name.c_str(), abs_path_buff)) {
MS_LOG(ERROR) << "Load MindIR get absolute path failed";
}
#endif
@ -283,20 +281,19 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
if (!inc_load) {
MSANFModelParser::LoadTensorMapClear();
} else {
model_parser.SetIncLoad();
}
if (is_lite) {
model_parser.SetLite();
}
if (inc_load) {
model_parser.SetIncLoad();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
return dstgraph_ptr;
}
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
MS_EXCEPTION_IF_NULL(buf);
std::string str((const char *)buf, buf_size);
std::string str(buf, buf_size);
mind_ir::ModelProto model_;
if (!model_.ParseFromString(str)) {
MS_LOG(ERROR) << "Parse model from buffer fail!";

View File

@ -82,8 +82,8 @@ class Callback:
Callback function will execute some operations in the current step or epoch.
It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`,
`loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
`cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.
`loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
`cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.
Examples:
>>> from mindspore import Model, nn

View File

@ -30,7 +30,7 @@ class LossMonitor(Callback):
If per_print_times is 0, do not print loss.
Args:
per_print_times (int): Print the loss every seconds. Default: 1.
per_print_times (int): How many steps to print once loss. Default: 1.
Raises:
ValueError: If per_print_times is not an integer or less than zero.

View File

@ -482,7 +482,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpt_file_name):
raise ValueError("The checkpoint file does not exist.")
raise ValueError("The checkpoint file: {} does not exist.".format(ckpt_file_name))
if ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")
@ -547,7 +547,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
param_not_load = []
for _, param in net.parameters_and_names():
if param.name in parameter_dict:
new_param = parameter_dict[param.name]
new_param = copy.deepcopy(parameter_dict[param.name])
if not isinstance(new_param, Parameter):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
@ -705,6 +705,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
inputs (Tensor): Inputs of the `net`, if the network has multiple inputs, incoming tuple(Tensor).
file_name (str): File name of the model to be exported.
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
Default: 'AIR'.
- AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
@ -713,14 +714,14 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
kwargs (dict): Configuration options dictionary.
- quant_mode (str): If the network is quantization aware training network, the quant_mode should
- quant_mode (str): If the network is a quantization aware training network, the quant_mode should
be set to "QUANT", else the quant_mode should be set to "NONQUANT".
- mean (float): The mean of input data after preprocessing, used for quantizing the first layer of network.
Default: 127.5.
- std_dev (float): The variance of input data after preprocessing,
used for quantizing the first layer of network. Default: 127.5.
- enc_key (byte): Byte type key used for encryption. Tha valid length is 16, 24, or 32.
- enc_mode (str): Specifies the encryption mode, take effect when enc_key is set.
used for quantizing the first layer of the network. Default: 127.5.
- enc_key (byte): Byte type key used for encryption. The valid length is 16, 24, or 32.
- enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'.
- dataset (Dataset): Specifies the preprocess methods of network.