fix dataset sink size error in pynative

This commit is contained in:
chujinjin 2021-04-21 18:28:00 +08:00
parent 2e43434221
commit a6db82aeee
2 changed files with 12 additions and 7 deletions

View File

@ -154,6 +154,16 @@ def connect_network_with_dataset(network, dataset_helper):
dataset.__me_inited__ = True
dataset_types, dataset_shapes = dataset_helper.types_shapes()
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
if hasattr(dataset_iter, "sink_size") and \
dataset_iter.sink_size == 1 and \
dataset.get_dataset_size() != 1 and \
hasattr(dataset_iter, "sink_count") and \
dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.PYNATIVE_MODE:
dataset_helper.get_data_info()
return network

View File

@ -442,13 +442,8 @@ class Model:
if sink_size == -1:
epoch_num = epoch
else:
if is_graph:
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
train_dataset.__total_batch__ = epoch * sink_size
else:
sink_size = -1
epoch_num = epoch
logger.warning("Loop sink is not supported in PyNative mode, so it will be performed with no loop sink")
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
train_dataset.__total_batch__ = epoch * sink_size
cb_params.cur_step_num = 0
cb_params.dataset_sink_mode = True