!46159 Support MapParameter save and load
Merge pull request !46159 from changzherui/support_mapparameter
commit 7836767cad
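For orientation, a minimal end-to-end sketch of the feature this PR adds, condensed from the test introduced at the bottom of this diff. The net and parameter names are illustrative and, like that test, it assumes a build where the experimental MapParameter is functional (the test targets GPU):

import os
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, save_checkpoint, load_checkpoint
from mindspore.experimental import MapParameter


class HashNet(nn.Cell):  # illustrative net holding a MapParameter
    def __init__(self):
        super().__init__()
        # int32 keys mapped to float32 vectors of shape (3,)
        self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))

    def construct(self, keys, values):
        self.m[keys] = values            # populate the map, as the new test does
        return self.m.get(keys, True)


net = HashNet()
net(Tensor([1, 2], dtype=ms.int32),
    Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32))
save_checkpoint(net, "map_parameter.ckpt")            # MapParameter is now serialized too
param_dict = load_checkpoint("map_parameter.ckpt")    # returns the stored arrays keyed by name
os.remove("map_parameter.ckpt")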
@@ -19,7 +19,10 @@ syntax = "proto2";
 message Checkpoint {
   message Value {
     required string tag = 1;
-    required TensorProto tensor = 2;
+    oneof value{
+      TensorProto tensor = 2;
+      MapTensorProto maptensor = 3;
+    }
   }
   repeated Value value = 1;
 }
@@ -33,3 +36,8 @@ message TensorProto {
   // The data of the tensor.
   required bytes tensor_content = 3;
 }
+
+
+message MapTensorProto {
+  repeated TensorProto tensor = 1;
+}
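A sketch of how one MapParameter lands in the Checkpoint message defined above, mirroring the _write_mapparameter helper added later in this diff. The generated-module import path is an assumption (protobuf output of this .proto); the field names come straight from the definitions above, and the arrays are illustrative:

import numpy as np
from mindspore.train.checkpoint_pb2 import Checkpoint  # assumed location of the generated class

ckpt = Checkpoint()
entry = ckpt.value.add()
entry.tag = "net.m"                                    # the MapParameter's name
keys = np.array([1, 2], dtype=np.int32)                # illustrative exported arrays
vals = np.array([[11, 11, 11], [22, 22, 22]], dtype=np.float32)
for arr, type_name in ((keys, "Int32"), (vals, "Float32")):
    t = entry.maptensor.tensor.add()                   # selects the maptensor branch of the oneof
    t.dims.extend(arr.shape)
    t.tensor_type = type_name
    t.tensor_content = arr.reshape(-1).tobytes()
serialized = ckpt.SerializeToString()                  # the bytes _exec_save appends to the .ckpt file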
@@ -53,6 +53,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common._utils import is_shape_unknown
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.compression.export import quant_export
+from mindspore.experimental import MapParameter
 from mindspore.parallel._cell_wrapper import get_allgather_cell
 from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
 from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
@@ -63,6 +64,7 @@ from mindspore.train._utils import read_proto
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
 from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs

+
 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                      "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
@@ -205,6 +207,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
        logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")


+# pylint: disable=too-many-function-args
 def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
     """Execute the process of saving checkpoint into file."""
     try:
@@ -217,6 +220,9 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
                plain_data = BytesIO()

            for name, value in data_list.items():
+                if value[0] == "mapparameter":
+                    _write_mapparameter(name, value, f)
+                    continue
                data_size = value[2].nbytes / 1024
                if data_size > SLICE_SIZE:
                    slice_count = math.ceil(data_size / SLICE_SIZE)
@@ -254,6 +260,20 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
        raise e


+def _write_mapparameter(name, value, f):
+    """Write map parameter into protobuf file."""
+    checkpoint_list = Checkpoint()
+    param_value = checkpoint_list.value.add()
+    param_value.tag = name
+    map_tensor = param_value.maptensor
+    for v in value[1:]:
+        tensor = map_tensor.tensor.add()
+        tensor.dims.extend(v[0])
+        tensor.tensor_type = v[1]
+        tensor.tensor_content = v[2].tobytes()
+    f.write(checkpoint_list.SerializeToString())
+
+
 def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
     if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
@@ -322,6 +342,14 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
        param_list = []
        for (key, value) in param_dict.items():
            each_param = {"name": key}
+            if isinstance(value, MapParameter):
+                param_data = []
+                for export_data in value.export_data():
+                    param_data.append(Tensor(export_data))
+                each_param["data"] = param_data
+                param_list.append(each_param)
+                continue
+
            param_data = Tensor(value.data.asnumpy())

            # in automatic model parallel scenario, some parameters were split to all the devices,
@@ -346,6 +374,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
    for param in save_obj:
        key = param["name"]
        data_list[key] = []
+        if isinstance(param["data"], list):
+            _save_mapparameter(data_list, param)
+            continue
        if isinstance(param["data"], str):
            data_list[key].append([0])
            data_list[key].append('str')
@@ -376,6 +407,22 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
    logger.info("Saving checkpoint process is finished.")


+def _save_mapparameter(data_list, param):
+    """Save map parameter into save_obj."""
+    data_list[param["name"]].append("mapparameter")
+    for value in param["data"]:
+        dims = []
+        tmp_list = []
+        for dim in value.shape:
+            dims.append(dim)
+        tmp_list.append(dims)
+        tensor_type = str(value.dtype)
+        tmp_list.append(tensor_type)
+        data = value.asnumpy().reshape(-1)
+        tmp_list.append(data)
+        data_list[param["name"]].append(tmp_list)
+
+
 def _check_append_dict(append_dict):
     """Check the argument append_dict for save_checkpoint."""
     if append_dict is None:
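For reference, a sketch of the intermediate entry that _save_mapparameter above builds in data_list and that _write_mapparameter later serializes; the parameter name and array contents are made up for illustration:

import numpy as np

# Hypothetical data_list entry for a MapParameter named "net.m".
# Element 0 is the marker that _exec_save checks; each later element is
# [dims, dtype string, flattened numpy data] for one exported array.
data_list = {
    "net.m": [
        "mapparameter",
        [[2], "Int32", np.array([1, 2], dtype=np.int32)],                            # e.g. exported keys
        [[2, 3], "Float32", np.array([11, 11, 11, 22, 22, 22], dtype=np.float32)],   # e.g. exported values
    ]
}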
@@ -704,6 +751,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
    for element_id, element in enumerate(checkpoint_list.value):
        if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
            continue
+        if element.tensor.ByteSize() == 0:
+            _load_mapparameter(element, parameter_dict)
+            continue
        data = element.tensor.tensor_content
        data_type = element.tensor.tensor_type
        np_type = tensor_to_np_type.get(data_type)
@@ -828,6 +878,20 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
    return whether_load


+def _load_mapparameter(element, parameter_dict):
+    """Load map parameter from ckpt file."""
+    map_array = []
+    for tensor in element.maptensor.tensor:
+        data = tensor.tensor_content
+        data_type = tensor.tensor_type
+        np_type = tensor_to_np_type.get(data_type)
+        element_data = np.frombuffer(data, np_type)
+        dims = tensor.dims
+        param_data = element_data.reshape(list(dims))
+        map_array.append(param_data)
+    parameter_dict[element.tag] = map_array
+
+
 def load_param_into_net(net, parameter_dict, strict_load=False):
     """
     Load parameters into network, return parameter list that are not loaded in the network.
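After this change, load_checkpoint returns such an entry as a plain list of numpy arrays (one per stored TensorProto, reshaped to its saved dims, in the order they were exported at save time). A hypothetical access pattern:

from mindspore import load_checkpoint

# "map_parameter.ckpt" and "net.m" are illustrative; "net.m" would be a MapParameter's name.
param_dict = load_checkpoint("map_parameter.ckpt")
for arr in param_dict["net.m"]:          # list of numpy.ndarray built by _load_mapparameter
    print(arr.shape, arr.dtype)          # shape/dtype restored from dims and tensor_type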
@@ -16,7 +16,7 @@ import os
 import pytest
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore import context, Tensor, Parameter
+from mindspore import context, Tensor, Parameter, save_checkpoint, load_checkpoint
 from mindspore.experimental import MapParameter
 from mindspore.common.initializer import initializer

@@ -86,3 +86,37 @@ def test_maptensor_put_get_export(ms_type):
     print("data:", data)
     assert len(data) == 3
     assert len(data[0]) == 4
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('ms_type', [ms.int32, ms.int64])
+def test_mapparameter_ckpt_save_load(ms_type):
+    """
+    Feature: MapParameter
+    Description: Test MapParameter, test save and load
+    Expectation: IR graph with MapParameter created without exceptions.
+    """
+    class MyNet(nn.Cell):
+        def __init__(self, ms_type):
+            nn.Cell.__init__(self)
+            self.m = MapParameter(key_dtype=ms_type, value_dtype=ms.float32, value_shape=(3,))
+            self.keys = Tensor([1, 2], dtype=ms_type)
+            self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)
+
+        def construct(self, ms_type):
+            self.m[self.keys] = self.values
+            key1 = Tensor([3], dtype=ms_type)
+            value1 = self.m.get(key1, True)
+            key2 = Tensor([4], dtype=ms_type)
+            value2 = self.m.get(key2, True)
+            return value1, value2, self.m
+
+    net = MyNet(ms_type)
+    net(ms_type)
+    file_name = "map_parameter.ckpt"
+    save_checkpoint(net, file_name)
+    assert os.path.exists(file_name)
+    load_checkpoint(file_name)
+    os.remove(file_name)