!29114 1.6 Model support change dataset
Merge pull request !29114 from wangnan39/model_support_change_dataset
This commit is contained in:
commit
9978d6bd29
|
@ -136,6 +136,17 @@ def _check_add_offload(dataset, dataset_helper, network):
|
|||
return network
|
||||
|
||||
|
||||
class _DatasetAux:
|
||||
def __deepcopy__(self, memodict):
|
||||
return None
|
||||
|
||||
|
||||
def _get_dataset_aux(dataset):
|
||||
if not hasattr(dataset, '__network_aux__'):
|
||||
dataset.__network_aux__ = _DatasetAux()
|
||||
return dataset.__network_aux__
|
||||
|
||||
|
||||
def connect_network_with_dataset(network, dataset_helper):
|
||||
"""
|
||||
Connect the `network` with dataset in `dataset_helper`.
|
||||
|
@ -173,6 +184,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
"""
|
||||
dataset_iter = dataset_helper.iter
|
||||
dataset = dataset_iter.dataset
|
||||
aux = _get_dataset_aux(dataset)
|
||||
|
||||
if isinstance(dataset_iter, _DatasetIterNormal):
|
||||
raise RuntimeError("The API 'connect_network_with_dataset' should be called in dataset sink mode.")
|
||||
|
@ -180,35 +192,37 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
if _is_role_sched() or _is_role_pserver():
|
||||
return network
|
||||
|
||||
if not hasattr(aux, '__network__'):
|
||||
aux.__network__ = network
|
||||
|
||||
if aux.__network__ is not network:
|
||||
raise ValueError("The dataset has been connected to other network, please check the code.")
|
||||
|
||||
queue_name = dataset.__transfer_dataset__.queue_name
|
||||
if _dynamic_sink_scenario(dataset, dataset_iter):
|
||||
if not hasattr(dataset_iter, '__network__'):
|
||||
dataset_iter.__network__ = network
|
||||
network = dataset_iter.__network__
|
||||
|
||||
dataset_types, dataset_shapes = dataset_helper.get_data_info()
|
||||
dataset_types = [pytype_to_dtype(x) for x in dataset_types]
|
||||
|
||||
key = str(dataset_types) + str(dataset_shapes)
|
||||
if hasattr(dataset_iter, '__network_manage__') and key in dataset_iter.__network_manage__:
|
||||
network = dataset_iter.__network_manage__[key]
|
||||
if hasattr(aux, '__network_manage__') and key in aux.__network_manage__:
|
||||
network = aux.__network_manage__[key]
|
||||
else:
|
||||
if _need_to_full():
|
||||
device_num = _get_device_num() // _get_pipeline_stages()
|
||||
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||
|
||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
||||
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
|
||||
dataset_iter, '__network_manage__') else dict()
|
||||
dataset_iter.__network_manage__[key] = network
|
||||
aux.__network_manage__ = aux.__network_manage__ if hasattr(aux, '__network_manage__') else dict()
|
||||
aux.__network_manage__[key] = network
|
||||
return network
|
||||
|
||||
if not hasattr(dataset, '__me_inited__') and \
|
||||
not context.get_context("enable_ge") and \
|
||||
context.get_context("device_target") in ("Ascend", "GPU"):
|
||||
dataset.__me_inited__ = True
|
||||
network = _check_add_offload(dataset, dataset_helper, network)
|
||||
network = _generate_network_with_dataset(network, dataset_helper, queue_name)
|
||||
if hasattr(aux, '__sink_network__'):
|
||||
network = aux.__sink_network__
|
||||
else:
|
||||
if not context.get_context("enable_ge") and context.get_context("device_target") in ("Ascend", "GPU"):
|
||||
network = _check_add_offload(dataset, dataset_helper, network)
|
||||
network = _generate_network_with_dataset(network, dataset_helper, queue_name)
|
||||
aux.__sink_network__ = network
|
||||
|
||||
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter):
|
||||
dataset_helper.get_data_info()
|
||||
|
|
|
@ -467,11 +467,10 @@ class Model:
|
|||
dataset_sink_mode=True,
|
||||
sink_size=sink_size)
|
||||
self._warmup_dataset(epoch, train_dataset, sink_size)
|
||||
self._train_network = train_network
|
||||
if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
|
||||
self._train_network.add_flags_recursive(is_first_iteration=True)
|
||||
train_network.add_flags_recursive(is_first_iteration=True)
|
||||
for inputs in train_dataset_helper:
|
||||
self._train_network.compile(*inputs)
|
||||
train_network.compile(*inputs)
|
||||
break
|
||||
|
||||
if valid_dataset:
|
||||
|
@ -483,11 +482,10 @@ class Model:
|
|||
valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
if context.get_auto_parallel_context("pipeline_stages") > 1:
|
||||
self._eval_network.add_flags_recursive(is_first_iteration=False)
|
||||
eval_network.add_flags_recursive(is_first_iteration=False)
|
||||
for inputs in valid_dataset_helper:
|
||||
self._eval_network.compile(*inputs)
|
||||
eval_network.compile(*inputs)
|
||||
break
|
||||
|
||||
@_save_final_ckpt
|
||||
|
@ -610,8 +608,7 @@ class Model:
|
|||
epoch_num=epoch_num,
|
||||
dataset_helper=dataset_helper)
|
||||
|
||||
self._train_network = train_network
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.train_network = train_network
|
||||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
|
@ -622,7 +619,7 @@ class Model:
|
|||
self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
|
||||
cb_params.train_dataset_element = inputs
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*inputs)
|
||||
outputs = train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
if _is_role_pserver():
|
||||
|
@ -854,15 +851,14 @@ class Model:
|
|||
dataset_helper, eval_network = self._exec_preprocess(is_train=False,
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.eval_network = eval_network
|
||||
cb_params.dataset_sink_mode = True
|
||||
list_callback.begin(run_context)
|
||||
list_callback.epoch_begin(run_context)
|
||||
for inputs in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._eval_network(*inputs)
|
||||
outputs = eval_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
@ -1100,12 +1096,11 @@ class Model:
|
|||
dataset=train_dataset,
|
||||
dataset_sink_mode=dataset_sink_mode,
|
||||
sink_size=sink_size)
|
||||
self._train_network = train_network
|
||||
for inputs in train_dataset_helper:
|
||||
self._train_network.compile(*inputs)
|
||||
train_network.compile(*inputs)
|
||||
break
|
||||
train_dataset.__model_hash__ = hash(self)
|
||||
return self._train_network.parameter_layout_dict
|
||||
return train_network.parameter_layout_dict
|
||||
|
||||
def infer_predict_layout(self, *predict_data):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue