!29114 1.6 Model support change dataset

Merge pull request !29114 from wangnan39/model_support_change_dataset
This commit is contained in:
i-robot 2022-01-19 01:50:30 +00:00 committed by Gitee
commit 9978d6bd29
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 39 additions and 30 deletions

View File

@ -136,6 +136,17 @@ def _check_add_offload(dataset, dataset_helper, network):
return network
class _DatasetAux:
def __deepcopy__(self, memodict):
return None
def _get_dataset_aux(dataset):
if not hasattr(dataset, '__network_aux__'):
dataset.__network_aux__ = _DatasetAux()
return dataset.__network_aux__
def connect_network_with_dataset(network, dataset_helper):
"""
Connect the `network` with dataset in `dataset_helper`.
@ -173,6 +184,7 @@ def connect_network_with_dataset(network, dataset_helper):
"""
dataset_iter = dataset_helper.iter
dataset = dataset_iter.dataset
aux = _get_dataset_aux(dataset)
if isinstance(dataset_iter, _DatasetIterNormal):
raise RuntimeError("The API 'connect_network_with_dataset' should be called in dataset sink mode.")
@ -180,35 +192,37 @@ def connect_network_with_dataset(network, dataset_helper):
if _is_role_sched() or _is_role_pserver():
return network
if not hasattr(aux, '__network__'):
aux.__network__ = network
if aux.__network__ is not network:
raise ValueError("The dataset has been connected to other network, please check the code.")
queue_name = dataset.__transfer_dataset__.queue_name
if _dynamic_sink_scenario(dataset, dataset_iter):
if not hasattr(dataset_iter, '__network__'):
dataset_iter.__network__ = network
network = dataset_iter.__network__
dataset_types, dataset_shapes = dataset_helper.get_data_info()
dataset_types = [pytype_to_dtype(x) for x in dataset_types]
key = str(dataset_types) + str(dataset_shapes)
if hasattr(dataset_iter, '__network_manage__') and key in dataset_iter.__network_manage__:
network = dataset_iter.__network_manage__[key]
if hasattr(aux, '__network_manage__') and key in aux.__network_manage__:
network = aux.__network_manage__[key]
else:
if _need_to_full():
device_num = _get_device_num() // _get_pipeline_stages()
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
dataset_iter, '__network_manage__') else dict()
dataset_iter.__network_manage__[key] = network
aux.__network_manage__ = aux.__network_manage__ if hasattr(aux, '__network_manage__') else dict()
aux.__network_manage__[key] = network
return network
if not hasattr(dataset, '__me_inited__') and \
not context.get_context("enable_ge") and \
context.get_context("device_target") in ("Ascend", "GPU"):
dataset.__me_inited__ = True
network = _check_add_offload(dataset, dataset_helper, network)
network = _generate_network_with_dataset(network, dataset_helper, queue_name)
if hasattr(aux, '__sink_network__'):
network = aux.__sink_network__
else:
if not context.get_context("enable_ge") and context.get_context("device_target") in ("Ascend", "GPU"):
network = _check_add_offload(dataset, dataset_helper, network)
network = _generate_network_with_dataset(network, dataset_helper, queue_name)
aux.__sink_network__ = network
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter):
dataset_helper.get_data_info()

View File

@ -467,11 +467,10 @@ class Model:
dataset_sink_mode=True,
sink_size=sink_size)
self._warmup_dataset(epoch, train_dataset, sink_size)
self._train_network = train_network
if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
self._train_network.add_flags_recursive(is_first_iteration=True)
train_network.add_flags_recursive(is_first_iteration=True)
for inputs in train_dataset_helper:
self._train_network.compile(*inputs)
train_network.compile(*inputs)
break
if valid_dataset:
@ -483,11 +482,10 @@ class Model:
valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
dataset=valid_dataset,
dataset_sink_mode=True)
self._eval_network = eval_network
if context.get_auto_parallel_context("pipeline_stages") > 1:
self._eval_network.add_flags_recursive(is_first_iteration=False)
eval_network.add_flags_recursive(is_first_iteration=False)
for inputs in valid_dataset_helper:
self._eval_network.compile(*inputs)
eval_network.compile(*inputs)
break
@_save_final_ckpt
@ -610,8 +608,7 @@ class Model:
epoch_num=epoch_num,
dataset_helper=dataset_helper)
self._train_network = train_network
cb_params.train_network = self._train_network
cb_params.train_network = train_network
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
@ -622,7 +619,7 @@ class Model:
self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
outputs = train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
if _is_role_pserver():
@ -854,15 +851,14 @@ class Model:
dataset_helper, eval_network = self._exec_preprocess(is_train=False,
dataset=valid_dataset,
dataset_sink_mode=True)
self._eval_network = eval_network
cb_params.eval_network = self._eval_network
cb_params.eval_network = eval_network
cb_params.dataset_sink_mode = True
list_callback.begin(run_context)
list_callback.epoch_begin(run_context)
for inputs in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
outputs = self._eval_network(*inputs)
outputs = eval_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
self._update_metrics(outputs)
@ -1100,12 +1096,11 @@ class Model:
dataset=train_dataset,
dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size)
self._train_network = train_network
for inputs in train_dataset_helper:
self._train_network.compile(*inputs)
train_network.compile(*inputs)
break
train_dataset.__model_hash__ = hash(self)
return self._train_network.parameter_layout_dict
return train_network.parameter_layout_dict
def infer_predict_layout(self, *predict_data):
"""