!18168 add ckpt file authority
Merge pull request !18168 from changzherui/add_file_authority
This commit is contained in:
commit
0cbc9fa155
|
@ -184,6 +184,10 @@ std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> fil
|
|||
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
if (file_name.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||
return nullptr;
|
||||
}
|
||||
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
|
||||
char abs_path_buff[PATH_MAX];
|
||||
char abs_path[PATH_MAX];
|
||||
|
|
|
@ -463,10 +463,7 @@ class CheckpointManager:
|
|||
for filename in files:
|
||||
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
|
||||
mid_name = filename[len(prefix):-5]
|
||||
flag = True
|
||||
for char in mid_name:
|
||||
if char.isalpha():
|
||||
flag = False
|
||||
flag = not (True in [char.isalpha() for char in mid_name])
|
||||
if flag:
|
||||
self._ckpoint_filelist.append(directory + '/' + filename)
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ from mindspore.common.api import _executor
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import check_input_data, Validator
|
||||
from mindspore.compression.export import quant_export
|
||||
from mindspore.parallel._tensor import _load_tensor
|
||||
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
||||
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
|
||||
|
@ -632,10 +632,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
|
|||
# while any dim is not equal to -1, means param is split and needs to be merged
|
||||
# pipeline parallel need to be supported here later
|
||||
if mp_weight:
|
||||
if opt_shard_group:
|
||||
allgather_net = get_allgather_cell(opt_shard_group, True)
|
||||
else:
|
||||
allgather_net = get_allgather_cell(opt_shard_group, False)
|
||||
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
|
||||
elif opt_shard_group:
|
||||
allgather_net = get_allgather_cell(opt_shard_group, False)
|
||||
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
|
||||
|
@ -758,7 +755,10 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
||||
if not file_name.endswith('.air'):
|
||||
file_name += ".air"
|
||||
if os.path.exists(file_name):
|
||||
os.chmod(file_name, stat.S_IWUSR)
|
||||
_executor.export(file_name, graph_id)
|
||||
os.chmod(file_name, stat.S_IRUSR)
|
||||
elif file_format == 'ONNX':
|
||||
total_size = _calculation_net_size(net)
|
||||
if total_size > PROTO_LIMIT_SIZE:
|
||||
|
@ -769,9 +769,11 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|||
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
|
||||
if not file_name.endswith('.onnx'):
|
||||
file_name += ".onnx"
|
||||
if os.path.exists(file_name):
|
||||
os.chmod(file_name, stat.S_IWUSR)
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
f.write(onnx_stream)
|
||||
os.chmod(file_name, stat.S_IRUSR)
|
||||
elif file_format == 'MINDIR':
|
||||
_save_mindir(net, file_name, *inputs, **kwargs)
|
||||
|
||||
|
@ -792,29 +794,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
net_dict = net.parameters_dict()
|
||||
model.ParseFromString(mindir_stream)
|
||||
|
||||
save_together = _mindir_save_together(net_dict, model)
|
||||
save_together = _save_together(net_dict, model)
|
||||
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
|
||||
if save_together:
|
||||
for param_proto in model.graph.parameter:
|
||||
param_name = param_proto.name[param_proto.name.find(":")+1:]
|
||||
if param_name in net_dict.keys():
|
||||
param_data = net_dict[param_name].data.asnumpy().tobytes()
|
||||
param_proto.raw_data = param_data
|
||||
else:
|
||||
logger.error("The parameter %s in the graph are not in the network.", param_name)
|
||||
raise ValueError("The parameter in the graph must in the network.")
|
||||
if not file_name.endswith('.mindir'):
|
||||
file_name += ".mindir"
|
||||
current_path = os.path.abspath(file_name)
|
||||
dirname = os.path.dirname(current_path)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
model_string = model.SerializeToString()
|
||||
if is_encrypt():
|
||||
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(model_string)
|
||||
_save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
|
||||
else:
|
||||
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
|
||||
# save parameter
|
||||
|
@ -823,7 +806,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
file_prefix = file_prefix[:-7]
|
||||
current_path = os.path.abspath(file_name)
|
||||
dirname = os.path.dirname(current_path)
|
||||
data_path = dirname + "/" + file_prefix + "_variables"
|
||||
data_path = os.path.join(dirname, file_prefix + "_variables")
|
||||
if os.path.exists(data_path):
|
||||
shutil.rmtree(data_path)
|
||||
os.makedirs(data_path, exist_ok=True)
|
||||
|
@ -845,7 +828,9 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
data_size += sys.getsizeof(byte_data) / 1024
|
||||
break
|
||||
if data_size > TOTAL_SAVE:
|
||||
data_file_name = data_path + "/" + "data_" + str(index)
|
||||
data_file_name = os.path.join(data_path, "data_" + str(index))
|
||||
if os.path.exists(data_file_name):
|
||||
os.chmod(data_file_name, stat.S_IWUSR)
|
||||
with open(data_file_name, "ab") as f:
|
||||
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
graph_string = graphproto.SerializeToString()
|
||||
|
@ -853,12 +838,15 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'],
|
||||
len(kwargs['enc_key']), kwargs['enc_mode'])
|
||||
f.write(graph_string)
|
||||
os.chmod(data_file_name, stat.S_IRUSR)
|
||||
index += 1
|
||||
data_size = 0
|
||||
del graphproto.parameter[:]
|
||||
|
||||
if graphproto.parameter:
|
||||
data_file_name = data_path + "/" + "data_" + str(index)
|
||||
data_file_name = os.path.join(data_path, "data_" + str(index))
|
||||
if os.path.exists(data_file_name):
|
||||
os.chmod(data_file_name, stat.S_IWUSR)
|
||||
with open(data_file_name, "ab") as f:
|
||||
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
graph_string = graphproto.SerializeToString()
|
||||
|
@ -866,10 +854,13 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(graph_string)
|
||||
os.chmod(data_file_name, stat.S_IRUSR)
|
||||
|
||||
# save graph
|
||||
del model.graph.parameter[:]
|
||||
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
|
||||
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
|
||||
if os.path.exists(graph_file_name):
|
||||
os.chmod(graph_file_name, stat.S_IWUSR)
|
||||
with open(graph_file_name, 'wb') as f:
|
||||
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
model_string = model.SerializeToString()
|
||||
|
@ -877,9 +868,37 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(model_string)
|
||||
os.chmod(graph_file_name, stat.S_IRUSR)
|
||||
|
||||
|
||||
def _mindir_save_together(net_dict, model):
|
||||
def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
||||
"""Save graph and parameter together."""
|
||||
for param_proto in model.graph.parameter:
|
||||
param_name = param_proto.name[param_proto.name.find(":") + 1:]
|
||||
if param_name in net_dict.keys():
|
||||
param_data = net_dict[param_name].data.asnumpy().tobytes()
|
||||
param_proto.raw_data = param_data
|
||||
else:
|
||||
logger.error("The parameter %s in the graph are not in the network.", param_name)
|
||||
raise ValueError("The parameter in the graph must in the network.")
|
||||
if not file_name.endswith('.mindir'):
|
||||
file_name += ".mindir"
|
||||
current_path = os.path.abspath(file_name)
|
||||
dirname = os.path.dirname(current_path)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
if os.path.exists(file_name):
|
||||
os.chmod(file_name, stat.S_IWUSR)
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
model_string = model.SerializeToString()
|
||||
if is_encrypt():
|
||||
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(model_string)
|
||||
os.chmod(file_name, stat.S_IRUSR)
|
||||
|
||||
|
||||
def _save_together(net_dict, model):
|
||||
"""Whether graph and parameter save together during save mindir model."""
|
||||
data_total = 0
|
||||
for param_proto in model.graph.parameter:
|
||||
|
@ -1059,7 +1078,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|||
merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
|
||||
|
||||
else:
|
||||
from mindspore.parallel._tensor import _get_tensor_strategy, _get_tensor_slice_index
|
||||
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
||||
|
||||
slice_count = 1
|
||||
|
|
Loading…
Reference in New Issue