forked from mindspore-Ecosystem/mindspore
!3273 Optimized checkpoint save slice tensor
Merge pull request !3273 from changzherui/save_slice_tensor
This commit is contained in:
commit
097b77c3b8
|
@ -15,6 +15,7 @@
|
|||
"""Model and parameters serialization."""
|
||||
import os
|
||||
import stat
|
||||
import math
|
||||
from threading import Thread, Lock
|
||||
import numpy as np
|
||||
|
||||
|
@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
|||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||
|
||||
_ckpt_mutex = Lock()
|
||||
SLICE_SIZE = 512 * 1024 * 1024
|
||||
|
||||
|
||||
def _special_process_par(par, new_par):
|
||||
"""
|
||||
|
@ -105,26 +108,38 @@ def _update_param(param, new_param):
|
|||
|
||||
def _exec_save(ckpt_file_name, data_list):
|
||||
"""Execute save checkpoint into file process."""
|
||||
checkpoint_list = Checkpoint()
|
||||
|
||||
try:
|
||||
with _ckpt_mutex:
|
||||
for name, value in data_list.items():
|
||||
param_value = checkpoint_list.value.add()
|
||||
param_value.tag = name
|
||||
param_tensor = param_value.tensor
|
||||
param_tensor.dims.extend(value[0])
|
||||
param_tensor.tensor_type = value[1]
|
||||
param_tensor.tensor_content = value[2].tostring()
|
||||
if os.path.exists(ckpt_file_name):
|
||||
os.remove(ckpt_file_name)
|
||||
with open(ckpt_file_name, "ab") as f:
|
||||
for name, value in data_list.items():
|
||||
data_size = value[2].nbytes
|
||||
if data_size > SLICE_SIZE:
|
||||
slice_count = math.ceil(data_size / SLICE_SIZE)
|
||||
param_slice_list = np.array_split(value[2], slice_count)
|
||||
else:
|
||||
param_slice_list = [value[2]]
|
||||
|
||||
with open(ckpt_file_name, "wb") as f:
|
||||
f.write(checkpoint_list.SerializeToString())
|
||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||
for param_slice in param_slice_list:
|
||||
checkpoint_list = Checkpoint()
|
||||
param_value = checkpoint_list.value.add()
|
||||
param_value.tag = name
|
||||
param_tensor = param_value.tensor
|
||||
param_tensor.dims.extend(value[0])
|
||||
param_tensor.tensor_type = value[1]
|
||||
param_tensor.tensor_content = param_slice.tostring()
|
||||
|
||||
f.write(checkpoint_list.SerializeToString())
|
||||
|
||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
|
||||
raise RuntimeError(e.__str__())
|
||||
|
||||
|
||||
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
||||
"""
|
||||
Saves checkpoint info to a specified file.
|
||||
|
@ -206,28 +221,37 @@ def load_checkpoint(ckpt_file_name, net=None):
|
|||
|
||||
parameter_dict = {}
|
||||
try:
|
||||
element_id = 0
|
||||
param_data_list = []
|
||||
for element in checkpoint_list.value:
|
||||
data = element.tensor.tensor_content
|
||||
data_type = element.tensor.tensor_type
|
||||
np_type = tensor_to_np_type[data_type]
|
||||
ms_type = tensor_to_ms_type[data_type]
|
||||
param_data = np.fromstring(data, np_type)
|
||||
dims = element.tensor.dims
|
||||
element_data = np.frombuffer(data, np_type)
|
||||
param_data_list.append(element_data)
|
||||
if (element_id == len(checkpoint_list.value) - 1) or \
|
||||
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
||||
param_data = np.concatenate((param_data_list), axis=0)
|
||||
param_data_list.clear()
|
||||
dims = element.tensor.dims
|
||||
|
||||
if dims == [0]:
|
||||
if 'Float' in data_type:
|
||||
param_data = float(param_data[0])
|
||||
elif 'Int' in data_type:
|
||||
param_data = int(param_data[0])
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
elif dims == [1]:
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
else:
|
||||
param_dim = []
|
||||
for dim in dims:
|
||||
param_dim.append(dim)
|
||||
param_value = param_data.reshape(param_dim)
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
||||
if dims == [0]:
|
||||
if 'Float' in data_type:
|
||||
param_data = float(param_data[0])
|
||||
elif 'Int' in data_type:
|
||||
param_data = int(param_data[0])
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
elif dims == [1]:
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
else:
|
||||
param_dim = []
|
||||
for dim in dims:
|
||||
param_dim.append(dim)
|
||||
param_value = param_data.reshape(param_dim)
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
||||
|
||||
element_id += 1
|
||||
|
||||
logger.info("Load checkpoint process finish.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue