add read air proto func
This commit is contained in:
parent
83b25e10e9
commit
e9ab8016eb
|
@ -192,7 +192,6 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|||
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
||||
parameter_proto->set_name(param_name);
|
||||
SetParamToTensorProto(param, parameter_proto);
|
||||
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
|
||||
} else {
|
||||
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
|
||||
input_proto->set_name(param_name);
|
||||
|
|
|
@ -183,7 +183,7 @@ class DatasetHelper:
|
|||
>>> train_dataset = create_custom_dataset()
|
||||
>>> set_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
|
||||
>>> for next_element in set_helper:
|
||||
>>> print(next_element)
|
||||
... print(next_element)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
|
||||
|
|
|
@ -889,7 +889,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|||
raise ValueError(f"The sliced_parameters length should be equal to device_count. "
|
||||
f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")
|
||||
|
||||
merged_tensor = None
|
||||
if not param_split_shape:
|
||||
if not is_even:
|
||||
raise ValueError("The shape of every parameter in sliced_parameters should be the same "
|
||||
|
@ -1052,7 +1051,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|||
layerwise_parallel = sliced_parameters[0].layerwise_parallel
|
||||
requires_grad = sliced_parameters[0].requires_grad
|
||||
sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
|
||||
merged_parameter = None
|
||||
|
||||
if not strategy:
|
||||
merged_tensor = Tensor(np.concatenate(sliced_data))
|
||||
|
@ -1121,7 +1119,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
param_rank = rank_list[param.name][0]
|
||||
skip_merge_split = rank_list[param.name][1]
|
||||
for rank in param_rank:
|
||||
sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
|
||||
sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name]
|
||||
sliced_params.append(sliced_param)
|
||||
if skip_merge_split:
|
||||
split_param = sliced_params[0]
|
||||
|
@ -1213,59 +1211,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|||
layerwise_parallel = merged_param.layerwise_parallel
|
||||
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
|
||||
return split_param
|
||||
|
||||
|
||||
def _load_single_param(ckpt_file_name, param_name):
    """Load a single parameter by name from a MindSpore checkpoint file.

    Args:
        ckpt_file_name (str): Path of the checkpoint (protobuf) file.
        param_name (str): Tag of the parameter to extract.

    Returns:
        Parameter, the parameter whose tag equals `param_name`.

    Raises:
        ValueError: If the file cannot be read/parsed, or no element with
            tag `param_name` exists in the checkpoint.
        RuntimeError: If the matched tensor data cannot be converted into
            a Parameter.
    """
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    # Narrowed from BaseException so KeyboardInterrupt/SystemExit propagate;
    # any real parse/IO failure is still converted to ValueError below.
    except Exception as e:
        logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
                     " please check the correct of the file.", ckpt_file_name)
        raise ValueError(e.__str__())

    parameter = None
    try:
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            if element.tag != param_name:
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            # A large parameter may be split across consecutive elements that
            # share the same tag; merge once the last chunk has been collected
            # (end of list, or the next element carries a different tag).
            is_last = element_id == len(checkpoint_list.value) - 1
            if is_last or element.tag != checkpoint_list.value[element_id + 1].tag:
                param_data = np.concatenate(param_data_list, axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    # Scalar parameter: unwrap to a Python scalar of the
                    # matching kind before wrapping in a Tensor.
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    # 1-D parameter: the concatenated buffer is already the value.
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    # N-D parameter: restore the recorded shape
                    # (list(dims) replaces the original element-by-element copy loop).
                    param_value = param_data.reshape(list(dims))
                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
                break
    except Exception as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__())

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter
|
||||
|
|
Loading…
Reference in New Issue