!18168 add ckpt file authority

Merge pull request !18168 from changzherui/add_file_authority
This commit is contained in:
i-robot 2021-06-17 14:16:54 +08:00 committed by Gitee
commit 0cbc9fa155
3 changed files with 56 additions and 37 deletions

View File

@ -184,6 +184,10 @@ std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> fil
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
const size_t key_len, const std::string &dec_mode) {
if (file_name.length() > PATH_MAX) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
}
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
char abs_path_buff[PATH_MAX];
char abs_path[PATH_MAX];

View File

@ -463,10 +463,7 @@ class CheckpointManager:
for filename in files:
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
mid_name = filename[len(prefix):-5]
flag = True
for char in mid_name:
if char.isalpha():
flag = False
flag = not (True in [char.isalpha() for char in mid_name])
if flag:
self._ckpoint_filelist.append(directory + '/' + filename)

View File

@ -41,7 +41,7 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.communication.management import get_rank, get_group_size
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
@ -632,10 +632,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
# while any dim is not equal to -1, means param is split and needs to be merged
# pipeline parallel need to be supported here later
if mp_weight:
if opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, True)
else:
allgather_net = get_allgather_cell(opt_shard_group, False)
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
elif opt_shard_group:
allgather_net = get_allgather_cell(opt_shard_group, False)
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
@ -758,7 +755,10 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
if not file_name.endswith('.air'):
file_name += ".air"
if os.path.exists(file_name):
os.chmod(file_name, stat.S_IWUSR)
_executor.export(file_name, graph_id)
os.chmod(file_name, stat.S_IRUSR)
elif file_format == 'ONNX':
total_size = _calculation_net_size(net)
if total_size > PROTO_LIMIT_SIZE:
@ -769,9 +769,11 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
if not file_name.endswith('.onnx'):
file_name += ".onnx"
if os.path.exists(file_name):
os.chmod(file_name, stat.S_IWUSR)
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(onnx_stream)
os.chmod(file_name, stat.S_IRUSR)
elif file_format == 'MINDIR':
_save_mindir(net, file_name, *inputs, **kwargs)
@ -792,29 +794,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
net_dict = net.parameters_dict()
model.ParseFromString(mindir_stream)
save_together = _mindir_save_together(net_dict, model)
save_together = _save_together(net_dict, model)
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
if save_together:
for param_proto in model.graph.parameter:
param_name = param_proto.name[param_proto.name.find(":")+1:]
if param_name in net_dict.keys():
param_data = net_dict[param_name].data.asnumpy().tobytes()
param_proto.raw_data = param_data
else:
logger.error("The parameter %s in the graph are not in the network.", param_name)
raise ValueError("The parameter in the graph must in the network.")
if not file_name.endswith('.mindir'):
file_name += ".mindir"
current_path = os.path.abspath(file_name)
dirname = os.path.dirname(current_path)
os.makedirs(dirname, exist_ok=True)
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(model_string)
_save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
else:
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
# save parameter
@ -823,7 +806,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
file_prefix = file_prefix[:-7]
current_path = os.path.abspath(file_name)
dirname = os.path.dirname(current_path)
data_path = dirname + "/" + file_prefix + "_variables"
data_path = os.path.join(dirname, file_prefix + "_variables")
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
@ -845,7 +828,9 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
data_size += sys.getsizeof(byte_data) / 1024
break
if data_size > TOTAL_SAVE:
data_file_name = data_path + "/" + "data_" + str(index)
data_file_name = os.path.join(data_path, "data_" + str(index))
if os.path.exists(data_file_name):
os.chmod(data_file_name, stat.S_IWUSR)
with open(data_file_name, "ab") as f:
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
graph_string = graphproto.SerializeToString()
@ -853,12 +838,15 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
f.write(graph_string)
os.chmod(data_file_name, stat.S_IRUSR)
index += 1
data_size = 0
del graphproto.parameter[:]
if graphproto.parameter:
data_file_name = data_path + "/" + "data_" + str(index)
data_file_name = os.path.join(data_path, "data_" + str(index))
if os.path.exists(data_file_name):
os.chmod(data_file_name, stat.S_IWUSR)
with open(data_file_name, "ab") as f:
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
graph_string = graphproto.SerializeToString()
@ -866,10 +854,13 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(graph_string)
os.chmod(data_file_name, stat.S_IRUSR)
# save graph
del model.graph.parameter[:]
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
if os.path.exists(graph_file_name):
os.chmod(graph_file_name, stat.S_IWUSR)
with open(graph_file_name, 'wb') as f:
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
@ -877,9 +868,37 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(model_string)
os.chmod(graph_file_name, stat.S_IRUSR)
def _mindir_save_together(net_dict, model):
def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
"""Save graph and parameter together."""
for param_proto in model.graph.parameter:
param_name = param_proto.name[param_proto.name.find(":") + 1:]
if param_name in net_dict.keys():
param_data = net_dict[param_name].data.asnumpy().tobytes()
param_proto.raw_data = param_data
else:
logger.error("The parameter %s in the graph are not in the network.", param_name)
raise ValueError("The parameter in the graph must in the network.")
if not file_name.endswith('.mindir'):
file_name += ".mindir"
current_path = os.path.abspath(file_name)
dirname = os.path.dirname(current_path)
os.makedirs(dirname, exist_ok=True)
if os.path.exists(file_name):
os.chmod(file_name, stat.S_IWUSR)
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(model_string)
os.chmod(file_name, stat.S_IRUSR)
def _save_together(net_dict, model):
"""Whether graph and parameter save together during save mindir model."""
data_total = 0
for param_proto in model.graph.parameter:
@ -1059,7 +1078,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
else:
from mindspore.parallel._tensor import _get_tensor_strategy, _get_tensor_slice_index
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
slice_count = 1