!20340 read or write file interface code check
Merge pull request !20340 from zhangbuxue/write_or_read_file_interface_code_check
This commit is contained in:
commit
c168ecce09
|
@ -424,7 +424,6 @@ bool ArithmeticCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
SquaredDifference(input1, input2, output);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -213,7 +213,6 @@ bool ArithmeticLogicCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs,
|
|||
LogicalOr(input1, input2, output);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -174,8 +174,7 @@ std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> fil
|
|||
const std::string &dec_mode, bool inc_load) {
|
||||
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
|
||||
MS_LOG(DEBUG) << "Load multiple MindIR files.";
|
||||
for (size_t i = 0; i < file_names.size(); ++i) {
|
||||
std::string file_name = file_names[i];
|
||||
for (const auto &file_name : file_names) {
|
||||
MS_LOG(DEBUG) << "Load " << file_name;
|
||||
funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, inc_load));
|
||||
}
|
||||
|
@ -188,10 +187,9 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
|
|||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||
return nullptr;
|
||||
}
|
||||
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
|
||||
const char *file_path = file_name.c_str();
|
||||
char abs_path_buff[PATH_MAX];
|
||||
char abs_path[PATH_MAX];
|
||||
|
||||
vector<string> files;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
@ -208,7 +206,11 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
|
|||
}
|
||||
// Load parameter into graph
|
||||
if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
|
||||
int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
|
||||
if (strlen(abs_path_buff) < strlen("graph.mindir")) {
|
||||
MS_LOG(ERROR) << "The abs_path_buff length is less than 'graph.mindir'.";
|
||||
return nullptr;
|
||||
}
|
||||
int path_len = SizeToInt(strlen(abs_path_buff) - strlen("graph.mindir"));
|
||||
int ret = memcpy_s(abs_path, sizeof(abs_path), abs_path_buff, path_len);
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Load MindIR occur memcpy_s error.";
|
||||
|
|
|
@ -275,6 +275,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|||
data = param["data"].asnumpy().reshape(-1)
|
||||
data_list[key].append(data)
|
||||
|
||||
if not isinstance(ckpt_file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
||||
if async_save:
|
||||
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt")
|
||||
thr.start()
|
||||
|
@ -344,10 +347,11 @@ def load(file_name, **kwargs):
|
|||
"""
|
||||
if not isinstance(file_name, str):
|
||||
raise ValueError("The file name must be string.")
|
||||
if not os.path.exists(file_name):
|
||||
raise ValueError("The file does not exist.")
|
||||
if not file_name.endswith(".mindir"):
|
||||
raise ValueError("The MindIR should end with mindir, please input the correct file name.")
|
||||
if not os.path.exists(file_name):
|
||||
raise ValueError("The file does not exist.")
|
||||
file_name = os.path.realpath(file_name)
|
||||
|
||||
logger.info("Execute the process of loading mindir.")
|
||||
if 'dec_key' in kwargs.keys():
|
||||
|
@ -479,6 +483,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
|||
|
||||
if ckpt_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("Please input the correct checkpoint file name.")
|
||||
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
||||
|
||||
if filter_prefix is not None:
|
||||
if not isinstance(filter_prefix, (str, list, tuple)):
|
||||
|
@ -597,6 +602,9 @@ def _save_graph(network, file_name):
|
|||
"""
|
||||
logger.info("Execute the process of saving graph.")
|
||||
|
||||
if not isinstance(file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
file_name = os.path.realpath(file_name)
|
||||
graph_pb = network.get_func_graph_proto()
|
||||
if graph_pb:
|
||||
with open(file_name, "wb") as f:
|
||||
|
@ -727,6 +735,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
|||
check_input_data(*inputs, data_class=Tensor)
|
||||
if not isinstance(file_name, str):
|
||||
raise ValueError("Args file_name {} must be string, please check it".format(file_name))
|
||||
file_name = os.path.realpath(file_name)
|
||||
|
||||
Validator.check_file_name_by_regular(file_name)
|
||||
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue