forked from mindspore-Ecosystem/mindspore
!10255 Use the do_copy option to choose Tensor.from_numpy for Tensor conversion
From: @ms_yan Reviewed-by: Signed-off-by:
This commit is contained in:
commit
1343c9cca2
|
@ -1255,7 +1255,7 @@ class Dataset:
|
|||
del api_tree
|
||||
|
||||
@check_tuple_iterator
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
|
||||
"""
|
||||
Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.
|
||||
|
||||
|
@ -1269,6 +1269,8 @@ class Dataset:
|
|||
(default=-1, the iterator can be iterated for an infinite number of epochs)
|
||||
output_numpy (bool, optional): Whether or not to output NumPy datatype.
|
||||
If output_numpy=False, iterator will output MSTensor (default=False).
|
||||
do_copy (bool, optional): When the output data type is mindspore.Tensor,
|
||||
use this parameter to select the conversion method; set it to False only for better performance (default=True).
|
||||
|
||||
Returns:
|
||||
Iterator, list of ndarrays.
|
||||
|
@ -1290,7 +1292,7 @@ class Dataset:
|
|||
|
||||
if Dataset._noop_mode():
|
||||
return DummyIterator(self, 'tuple')
|
||||
return TupleIterator(self, columns, num_epochs, output_numpy)
|
||||
return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
|
||||
|
||||
@check_dict_iterator
|
||||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||
|
@ -2788,7 +2790,7 @@ class TransferDataset(Dataset):
|
|||
def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
|
||||
raise RuntimeError("TransferDataset is not iterable.")
|
||||
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
|
||||
def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
|
||||
raise RuntimeError("TransferDataset is not iterable.")
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -63,7 +63,7 @@ class Iterator:
|
|||
dataset: Dataset to be iterated over
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, num_epochs=-1, output_numpy=False):
|
||||
def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True):
|
||||
self._col_names = None
|
||||
|
||||
# create a copy of tree and work on it.
|
||||
|
@ -80,7 +80,10 @@ class Iterator:
|
|||
|
||||
self._transform_tensor = lambda t: t.as_array()
|
||||
if not output_numpy:
|
||||
self._transform_tensor = lambda t: Tensor(t.as_array())
|
||||
if do_copy:
|
||||
self._transform_tensor = lambda t: Tensor(t.as_array())
|
||||
else:
|
||||
self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array())
|
||||
self._index = 0
|
||||
|
||||
# todo remove next when ContextManager is done
|
||||
|
@ -179,13 +182,13 @@ class TupleIterator(Iterator):
|
|||
The derived class of Iterator with list type.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
|
||||
def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False, do_copy=True):
|
||||
if columns is not None:
|
||||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
# todo: move next to IR
|
||||
dataset = dataset.project(columns)
|
||||
super().__init__(dataset, num_epochs, output_numpy)
|
||||
super().__init__(dataset, num_epochs, output_numpy, do_copy)
|
||||
|
||||
def _get_next(self):
|
||||
"""
|
||||
|
|
|
@ -298,7 +298,7 @@ def check_tuple_iterator(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
[columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)
|
||||
nreq_param_bool = ['output_numpy']
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
if num_epochs is not None:
|
||||
|
|
|
@ -394,7 +394,7 @@ class _DatasetIterNormal:
|
|||
self.dataset = dataset
|
||||
self.device_num = _get_device_num()
|
||||
self.global_rank = _get_global_rank()
|
||||
self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num)
|
||||
self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num, do_copy=False)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
|
|
@ -55,7 +55,7 @@ class MindData:
|
|||
self.send_epoch_end = send_epoch_end
|
||||
return self
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self.__iter__()
|
||||
|
||||
def send(self, num_epochs=-1):
|
||||
|
|
|
@ -125,7 +125,7 @@ class FakeData:
|
|||
def set_label_onehot(self, is_onehot=True):
|
||||
self.is_onehot = is_onehot
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
_ = num_epochs
|
||||
return self
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ class FakeData:
|
|||
def set_label_onehot(self, is_onehot=True):
|
||||
self.is_onehot = is_onehot
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
_ = num_epochs
|
||||
return self
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ class MindData:
|
|||
def output_shapes(self):
|
||||
return self._output_shapes
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self
|
||||
|
||||
@property
|
||||
|
|
|
@ -152,7 +152,7 @@ class DatasetLenet():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self
|
||||
|
||||
def test_double_subgraphs_train():
|
||||
|
|
|
@ -275,7 +275,7 @@ class DatasetLenet():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ class DatasetLenet():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ class Dataset():
|
|||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1):
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ class DatasetLenet():
|
|||
def get_batch_size(self):
|
||||
return 32
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=1):
|
||||
def create_tuple_iterator(self, num_epochs=1, do_copy=True):
|
||||
return self
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue