!3273 Optimized checkpoint save slice tensor

Merge pull request !3273 from changzherui/save_slice_tensor
This commit is contained in:
mindspore-ci-bot 2020-07-22 16:21:43 +08:00 committed by Gitee
commit 097b77c3b8
1 changed file with 51 additions and 27 deletions

View File

@ -15,6 +15,7 @@
"""Model and parameters serialization."""
import os
import stat
import math
from threading import Thread, Lock
import numpy as np
@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
_ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024 * 1024
def _special_process_par(par, new_par):
"""
@ -105,26 +108,38 @@ def _update_param(param, new_param):
def _exec_save(ckpt_file_name, data_list):
    """
    Execute save checkpoint into file process.

    Every parameter is written as one or more ``Checkpoint`` messages that
    are appended to the file back to back. A parameter whose raw data is
    larger than ``SLICE_SIZE`` bytes is split into several slices, each
    serialized as its own message carrying the same tag, dims and tensor
    type, so that no single serialized message has to hold the whole tensor.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write.
        data_list (dict): Maps a parameter name to an indexable value where
            ``value[0]`` is the dims, ``value[1]`` the tensor type and
            ``value[2]`` the data (presumably a NumPy array; it must provide
            ``nbytes`` and be accepted by ``np.array_split``).

    Raises:
        RuntimeError: If anything fails while writing the file. The original
            exception is attached as ``__cause__``.
    """
    try:
        with _ckpt_mutex:
            # A previous save leaves the file read-only (see the chmod
            # below), so it is removed rather than overwritten in place.
            if os.path.exists(ckpt_file_name):
                os.remove(ckpt_file_name)

            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    data = value[2]
                    data_size = data.nbytes
                    if data_size > SLICE_SIZE:
                        slice_count = math.ceil(data_size / SLICE_SIZE)
                        param_slice_list = np.array_split(data, slice_count)
                    else:
                        param_slice_list = [data]

                    for param_slice in param_slice_list:
                        # One fresh message per slice keeps the size of each
                        # serialized message bounded.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        # Every slice records the dims of the full tensor.
                        param_tensor.dims.extend(value[0])
                        param_tensor.tensor_type = value[1]
                        # tobytes() is the non-deprecated spelling of
                        # tostring() and returns the same bytes.
                        param_tensor.tensor_content = param_slice.tobytes()

                        f.write(checkpoint_list.SerializeToString())

            # Mark the finished checkpoint as owner-read-only.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(e.__str__()) from e
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
"""
Saves checkpoint info to a specified file.
@ -206,12 +221,19 @@ def load_checkpoint(ckpt_file_name, net=None):
parameter_dict = {}
try:
element_id = 0
param_data_list = []
for element in checkpoint_list.value:
data = element.tensor.tensor_content
data_type = element.tensor.tensor_type
np_type = tensor_to_np_type[data_type]
ms_type = tensor_to_ms_type[data_type]
param_data = np.fromstring(data, np_type)
element_data = np.frombuffer(data, np_type)
param_data_list.append(element_data)
if (element_id == len(checkpoint_list.value) - 1) or \
(element.tag != checkpoint_list.value[element_id + 1].tag):
param_data = np.concatenate((param_data_list), axis=0)
param_data_list.clear()
dims = element.tensor.dims
if dims == [0]:
@ -229,6 +251,8 @@ def load_checkpoint(ckpt_file_name, net=None):
param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
element_id += 1
logger.info("Load checkpoint process finish.")
except BaseException as e: