!46159 support MapParameter save and load

Merge pull request !46159 from changzherui/support_mapparameter
i-robot 2022-11-30 10:31:56 +00:00 committed by Gitee
commit 7836767cad
3 changed files with 108 additions and 2 deletions
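
In short, this change lets `save_checkpoint` and `load_checkpoint` round-trip `MapParameter` data. A minimal usage sketch modeled on the test case added at the bottom of this change; the cell, the sample keys/values, and the file name are illustrative, not part of the commit:

```python
# Sketch of the workflow this change enables (mirrors the new test case below).
# DemoNet, the sample keys/values and the file name are illustrative only.
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, save_checkpoint, load_checkpoint
from mindspore.experimental import MapParameter


class DemoNet(nn.Cell):
    def __init__(self):
        super().__init__()
        # Hash-map style parameter: int32 keys mapped to float32 vectors of shape (3,).
        self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))

    def construct(self, keys, values):
        self.m[keys] = values      # insert entries, as the new test does in construct
        return self.m


net = DemoNet()
net(Tensor([1, 2], dtype=ms.int32), Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32))
save_checkpoint(net, "map_parameter.ckpt")           # MapParameter entries are serialized too
param_dict = load_checkpoint("map_parameter.ckpt")   # map entries come back as numpy arrays
```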

View File

@ -19,7 +19,10 @@ syntax = "proto2";
message Checkpoint {
  message Value {
    required string tag = 1;
    required TensorProto tensor = 2;
    oneof value{
      TensorProto tensor = 2;
      MapTensorProto maptensor = 3;
    }
  }
  repeated Value value = 1;
}
@ -33,3 +36,8 @@ message TensorProto {
  // The data of the tensor.
  required bytes tensor_content = 3;
}

message MapTensorProto {
  repeated TensorProto tensor = 1;
}
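
With the new `oneof`, a `Value` now carries either a plain `TensorProto` or a `MapTensorProto` (a list of `TensorProto`s). A minimal reader sketch, assuming the generated protobuf module is importable as `mindspore.train.checkpoint_pb2` and the checkpoint was written without encryption (both assumptions, not part of this change):

```python
# Sketch only: inspect which branch of the new `oneof value` each entry uses.
# Assumes the generated protobuf module lives at mindspore.train.checkpoint_pb2
# and that the checkpoint file was written unencrypted.
from mindspore.train.checkpoint_pb2 import Checkpoint

ckpt = Checkpoint()
with open("map_parameter.ckpt", "rb") as f:      # illustrative file name
    ckpt.ParseFromString(f.read())               # concatenated messages merge into one

for value in ckpt.value:
    if value.WhichOneof("value") == "maptensor":
        # MapParameter entry: one TensorProto per array written by _write_mapparameter
        print(value.tag, "map parameter with", len(value.maptensor.tensor), "tensors")
    else:
        print(value.tag, "tensor of type", value.tensor.tensor_type)
```

The loader added below makes the same distinction by checking `element.tensor.ByteSize() == 0` instead of `WhichOneof`.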

View File

@ -53,6 +53,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.communication.management import get_rank, get_group_size
from mindspore.compression.export import quant_export
from mindspore.experimental import MapParameter
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
@ -63,6 +64,7 @@ from mindspore.train._utils import read_proto
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs

tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                     "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                     "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
@ -205,6 +207,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
        logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")


# pylint: disable=too-many-function-args
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
    """Execute the process of saving checkpoint into file."""
    try:
@ -217,6 +220,9 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
                plain_data = BytesIO()

            for name, value in data_list.items():
                if value[0] == "mapparameter":
                    _write_mapparameter(name, value, f)
                    continue
                data_size = value[2].nbytes / 1024
                if data_size > SLICE_SIZE:
                    slice_count = math.ceil(data_size / SLICE_SIZE)
@ -254,6 +260,20 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
        raise e


def _write_mapparameter(name, value, f):
    """Write map parameter into protobuf file."""
    checkpoint_list = Checkpoint()
    param_value = checkpoint_list.value.add()
    param_value.tag = name
    map_tensor = param_value.maptensor
    for v in value[1:]:
        tensor = map_tensor.tensor.add()
        tensor.dims.extend(v[0])
        tensor.tensor_type = v[1]
        tensor.tensor_content = v[2].tobytes()
    f.write(checkpoint_list.SerializeToString())


def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
    """Check save_obj and ckpt_file_name for save_checkpoint."""
    if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
@ -322,6 +342,14 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
        param_list = []
        for (key, value) in param_dict.items():
            each_param = {"name": key}
            if isinstance(value, MapParameter):
                param_data = []
                for export_data in value.export_data():
                    param_data.append(Tensor(export_data))
                each_param["data"] = param_data
                param_list.append(each_param)
                continue

            param_data = Tensor(value.data.asnumpy())
            # in automatic model parallel scenario, some parameters were split to all the devices,
@ -346,6 +374,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
    for param in save_obj:
        key = param["name"]
        data_list[key] = []
        if isinstance(param["data"], list):
            _save_mapparameter(data_list, param)
            continue
        if isinstance(param["data"], str):
            data_list[key].append([0])
            data_list[key].append('str')
@ -376,6 +407,22 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
    logger.info("Saving checkpoint process is finished.")


def _save_mapparameter(data_list, param):
    """Save map parameter into data_list."""
    data_list[param["name"]].append("mapparameter")
    for value in param["data"]:
        dims = []
        tmp_list = []
        for dim in value.shape:
            dims.append(dim)
        tmp_list.append(dims)
        tensor_type = str(value.dtype)
        tmp_list.append(tensor_type)
        data = value.asnumpy().reshape(-1)
        tmp_list.append(data)
        data_list[param["name"]].append(tmp_list)
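
Taken together with `_write_mapparameter` above, the per-parameter entry in `data_list` is a plain Python list: the literal string `"mapparameter"` followed by one `[dims, dtype_string, flattened_numpy_array]` triple per exported array. A small sketch of that layout; the key name `"m"` and the concrete arrays are made up for the example:

```python
# Illustrative shape of a data_list entry for a MapParameter named "m".
# Built by _save_mapparameter and consumed by _write_mapparameter; the
# arrays below are invented for the example.
import numpy as np

data_list = {
    "m": [
        "mapparameter",                                    # marker checked in _exec_save
        [[2], "Int32", np.array([1, 2], dtype=np.int32)],  # e.g. exported keys
        [[2, 3], "Float32", np.array([11, 11, 11, 22, 22, 22], dtype=np.float32)],  # values, flattened
    ]
}
```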
def _check_append_dict(append_dict):
"""Check the argument append_dict for save_checkpoint."""
if append_dict is None:
@ -704,6 +751,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
        for element_id, element in enumerate(checkpoint_list.value):
            if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
                continue
            if element.tensor.ByteSize() == 0:
                _load_mapparameter(element, parameter_dict)
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type.get(data_type)
@ -828,6 +878,20 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
    return whether_load


def _load_mapparameter(element, parameter_dict):
    """Load map parameter from ckpt file."""
    map_array = []
    for tensor in element.maptensor.tensor:
        data = tensor.tensor_content
        data_type = tensor.tensor_type
        np_type = tensor_to_np_type.get(data_type)
        element_data = np.frombuffer(data, np_type)
        dims = tensor.dims
        param_data = element_data.reshape(list(dims))
        map_array.append(param_data)
    parameter_dict[element.tag] = map_array
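
Unlike ordinary entries, which become `Parameter` objects, a map entry in the returned `parameter_dict` is a plain list of numpy arrays (one per `TensorProto`). A quick inspection sketch, reusing the illustrative net and file name from the sketch near the top of this page:

```python
# Sketch: what load_checkpoint() hands back for a MapParameter entry.
# "map_parameter.ckpt" and the key "m" are the illustrative names used earlier.
from mindspore import load_checkpoint

param_dict = load_checkpoint("map_parameter.ckpt")
map_arrays = param_dict["m"]            # a list of numpy arrays, not a Parameter
for arr in map_arrays:
    print(arr.dtype, arr.shape)         # the arrays originally produced by export_data()
```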
def load_param_into_net(net, parameter_dict, strict_load=False):
"""
Load parameters into network, return parameter list that are not loaded in the network.

View File

@ -16,7 +16,7 @@ import os
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor, Parameter
from mindspore import context, Tensor, Parameter, save_checkpoint, load_checkpoint
from mindspore.experimental import MapParameter
from mindspore.common.initializer import initializer
@ -86,3 +86,37 @@ def test_maptensor_put_get_export(ms_type):
    print("data:", data)
    assert len(data) == 3
    assert len(data[0]) == 4


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('ms_type', [ms.int32, ms.int64])
def test_mapparameter_ckpt_save_load(ms_type):
    """
    Feature: MapParameter
    Description: Test MapParameter checkpoint save and load.
    Expectation: Checkpoint with MapParameter is saved and loaded without exceptions.
    """
    class MyNet(nn.Cell):
        def __init__(self, ms_type):
            nn.Cell.__init__(self)
            self.m = MapParameter(key_dtype=ms_type, value_dtype=ms.float32, value_shape=(3,))
            self.keys = Tensor([1, 2], dtype=ms_type)
            self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)

        def construct(self, ms_type):
            self.m[self.keys] = self.values
            key1 = Tensor([3], dtype=ms_type)
            value1 = self.m.get(key1, True)
            key2 = Tensor([4], dtype=ms_type)
            value2 = self.m.get(key2, True)
            return value1, value2, self.m

    net = MyNet(ms_type)
    net(ms_type)
    file_name = "map_parameter.ckpt"
    save_checkpoint(net, file_name)
    assert os.path.exists(file_name)
    load_checkpoint(file_name)
    os.remove(file_name)