forked from mindspore-Ecosystem/mindspore
fix dataset sink size error in pynative
This commit is contained in:
parent
2e43434221
commit
a6db82aeee
|
@ -154,6 +154,16 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
dataset.__me_inited__ = True
|
||||
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
||||
|
||||
if hasattr(dataset_iter, "sink_size") and \
|
||||
dataset_iter.sink_size == 1 and \
|
||||
dataset.get_dataset_size() != 1 and \
|
||||
hasattr(dataset_iter, "sink_count") and \
|
||||
dataset_iter.sink_count == 1 and \
|
||||
context.get_context("device_target") == "Ascend" and \
|
||||
context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
dataset_helper.get_data_info()
|
||||
|
||||
return network
|
||||
|
||||
|
||||
|
|
|
@ -442,13 +442,8 @@ class Model:
|
|||
if sink_size == -1:
|
||||
epoch_num = epoch
|
||||
else:
|
||||
if is_graph:
|
||||
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
||||
train_dataset.__total_batch__ = epoch * sink_size
|
||||
else:
|
||||
sink_size = -1
|
||||
epoch_num = epoch
|
||||
logger.warning("Loop sink is not supported in PyNative mode, so it will be performed with no loop sink")
|
||||
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
||||
train_dataset.__total_batch__ = epoch * sink_size
|
||||
|
||||
cb_params.cur_step_num = 0
|
||||
cb_params.dataset_sink_mode = True
|
||||
|
|
Loading…
Reference in New Issue