diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 709c3cd326c..5fbc3913c16 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -93,8 +93,9 @@ def connect_network_with_dataset(network, dataset_helper): raise RuntimeError("Dataset should be connected with network only in sink mode.") if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ - and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ - and context.get_context("device_target") == "Ascend": + and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ + and context.get_context("device_target") == "Ascend" \ + and context.get_context("mode") == context.GRAPH_MODE: if not hasattr(dataset, '__network__'): dataset.__network__ = network @@ -206,6 +207,7 @@ class DatasetHelper: def get_data_info(self): return self.iter.get_data_info() + class _DatasetIter: """Base iter for dataset helper""" @@ -286,6 +288,7 @@ class _DatasetIterGE(_DatasetIter): self.op = op + class _DatasetIterPyNative(_DatasetIter): """Iter for MS(enable_loop_sink=False).""" @@ -301,6 +304,7 @@ class _DatasetIterPyNative(_DatasetIter): self.op = op + class _DatasetIterMSLoopSink(_DatasetIter): """Iter for context (device_target=Ascend)""" @@ -354,6 +358,7 @@ class _DatasetIterPSLite(_DatasetIter): class _DatasetIterNormal: """Iter for normal(non sink) mode, feed the data from host.""" + def __init__(self, dataset, epoch_num=-1): self.dataset = dataset self.device_num = _get_device_num()