forked from mindspore-Ecosystem/mindspore
fix getnext error in pynative
This commit is contained in:
parent
e7ea724386
commit
07b965dad9
|
@ -94,7 +94,8 @@ def connect_network_with_dataset(network, dataset_helper):
|
||||||
|
|
||||||
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
|
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
|
||||||
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
|
||||||
and context.get_context("device_target") == "Ascend":
|
and context.get_context("device_target") == "Ascend" \
|
||||||
|
and context.get_context("mode") == context.GRAPH_MODE:
|
||||||
|
|
||||||
if not hasattr(dataset, '__network__'):
|
if not hasattr(dataset, '__network__'):
|
||||||
dataset.__network__ = network
|
dataset.__network__ = network
|
||||||
|
@ -206,6 +207,7 @@ class DatasetHelper:
|
||||||
def get_data_info(self):
|
def get_data_info(self):
|
||||||
return self.iter.get_data_info()
|
return self.iter.get_data_info()
|
||||||
|
|
||||||
|
|
||||||
class _DatasetIter:
|
class _DatasetIter:
|
||||||
"""Base iter for dataset helper"""
|
"""Base iter for dataset helper"""
|
||||||
|
|
||||||
|
@ -286,6 +288,7 @@ class _DatasetIterGE(_DatasetIter):
|
||||||
|
|
||||||
self.op = op
|
self.op = op
|
||||||
|
|
||||||
|
|
||||||
class _DatasetIterPyNative(_DatasetIter):
|
class _DatasetIterPyNative(_DatasetIter):
|
||||||
"""Iter for MS(enable_loop_sink=False)."""
|
"""Iter for MS(enable_loop_sink=False)."""
|
||||||
|
|
||||||
|
@ -301,6 +304,7 @@ class _DatasetIterPyNative(_DatasetIter):
|
||||||
|
|
||||||
self.op = op
|
self.op = op
|
||||||
|
|
||||||
|
|
||||||
class _DatasetIterMSLoopSink(_DatasetIter):
|
class _DatasetIterMSLoopSink(_DatasetIter):
|
||||||
"""Iter for context (device_target=Ascend)"""
|
"""Iter for context (device_target=Ascend)"""
|
||||||
|
|
||||||
|
@ -354,6 +358,7 @@ class _DatasetIterPSLite(_DatasetIter):
|
||||||
|
|
||||||
class _DatasetIterNormal:
|
class _DatasetIterNormal:
|
||||||
"""Iter for normal(non sink) mode, feed the data from host."""
|
"""Iter for normal(non sink) mode, feed the data from host."""
|
||||||
|
|
||||||
def __init__(self, dataset, epoch_num=-1):
|
def __init__(self, dataset, epoch_num=-1):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.device_num = _get_device_num()
|
self.device_num = _get_device_num()
|
||||||
|
|
Loading…
Reference in New Issue