!20340 read or write file interface code check

Merge pull request !20340 from zhangbuxue/write_or_read_file_interface_code_check
This commit is contained in:
i-robot 2021-07-21 01:47:02 +00:00 committed by Gitee
commit c168ecce09
4 changed files with 18 additions and 9 deletions

View File

@ -424,7 +424,6 @@ bool ArithmeticCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const
SquaredDifference(input1, input2, output);
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
return false;
}
return true;
}

View File

@ -213,7 +213,6 @@ bool ArithmeticLogicCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs,
LogicalOr(input1, input2, output);
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
return false;
}
return true;
}

View File

@ -174,8 +174,7 @@ std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> fil
const std::string &dec_mode, bool inc_load) {
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
MS_LOG(DEBUG) << "Load multiple MindIR files.";
for (size_t i = 0; i < file_names.size(); ++i) {
std::string file_name = file_names[i];
for (const auto &file_name : file_names) {
MS_LOG(DEBUG) << "Load " << file_name;
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, inc_load));
}
@ -188,10 +187,9 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
}
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
const char *file_path = file_name.c_str();
char abs_path_buff[PATH_MAX];
char abs_path[PATH_MAX];
vector<string> files;
#ifdef _WIN32
@ -208,7 +206,11 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
}
// Load parameter into graph
if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
if (strlen(abs_path_buff) < strlen("graph.mindir")) {
MS_LOG(ERROR) << "The abs_path_buff length is less than 'graph.mindir'.";
return nullptr;
}
int path_len = SizeToInt(strlen(abs_path_buff) - strlen("graph.mindir"));
int ret = memcpy_s(abs_path, sizeof(abs_path), abs_path_buff, path_len);
if (ret != 0) {
MS_LOG(ERROR) << "Load MindIR occur memcpy_s error.";

View File

@ -275,6 +275,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
data = param["data"].asnumpy().reshape(-1)
data_list[key].append(data)
if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.")
ckpt_file_name = os.path.realpath(ckpt_file_name)
if async_save:
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt")
thr.start()
@ -344,10 +347,11 @@ def load(file_name, **kwargs):
"""
if not isinstance(file_name, str):
raise ValueError("The file name must be string.")
if not os.path.exists(file_name):
raise ValueError("The file does not exist.")
if not file_name.endswith(".mindir"):
raise ValueError("The MindIR should end with mindir, please input the correct file name.")
if not os.path.exists(file_name):
raise ValueError("The file does not exist.")
file_name = os.path.realpath(file_name)
logger.info("Execute the process of loading mindir.")
if 'dec_key' in kwargs.keys():
@ -479,6 +483,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
if ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")
ckpt_file_name = os.path.realpath(ckpt_file_name)
if filter_prefix is not None:
if not isinstance(filter_prefix, (str, list, tuple)):
@ -597,6 +602,9 @@ def _save_graph(network, file_name):
"""
logger.info("Execute the process of saving graph.")
if not isinstance(file_name, str):
raise ValueError("The ckpt_file_name must be string.")
file_name = os.path.realpath(file_name)
graph_pb = network.get_func_graph_proto()
if graph_pb:
with open(file_name, "wb") as f:
@ -727,6 +735,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
check_input_data(*inputs, data_class=Tensor)
if not isinstance(file_name, str):
raise ValueError("Args file_name {} must be string, please check it".format(file_name))
file_name = os.path.realpath(file_name)
Validator.check_file_name_by_regular(file_name)
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)