diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7f85aba6942..204c271eac5 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -134,8 +134,8 @@ class Dataset: """ def __init__(self, num_parallel_workers=None): - self.input = [] - self.output = [] + self.children = [] + self.parent = [] self.num_parallel_workers = num_parallel_workers self._device_iter = 0 self._input_indexs = () @@ -1006,9 +1006,9 @@ class Dataset: dev_id = output_dataset.shard_id return "", dev_id - if not output_dataset.input: + if not output_dataset.children: raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset))) - input_dataset = output_dataset.input[0] + input_dataset = output_dataset.children[0] return get_distribution(input_dataset) distribution_path, device_id = get_distribution(self) @@ -1129,8 +1129,8 @@ class Dataset: Return: Number, number of batches. """ - if self.input: - return self.input[0].get_dataset_size() + if self.children: + return self.children[0].get_dataset_size() return None def num_classes(self): @@ -1140,23 +1140,23 @@ class Dataset: Return: Number, number of classes. """ - if self.input: - return self.input[0].num_classes() + if self.children: + return self.children[0].num_classes() return None def get_sync_notifiers(self): - if self.input: - return self.input[0].get_sync_notifiers() + if self.children: + return self.children[0].get_sync_notifiers() return {} def disable_sync(self): - if self.input: - return self.input[0].disable_sync() + if self.children: + return self.children[0].disable_sync() return {} def is_sync(self): - if self.input: - return self.input[0].is_sync() + if self.children: + return self.children[0].is_sync() return False def sync_update(self, condition_name, num_batch=None, data=None): @@ -1190,8 +1190,8 @@ class Dataset: Return: Number, the number of data in a batch. """ - if self.input: - return self.input[0].get_batch_size() + if self.children: + return self.children[0].get_batch_size() return 1 def get_repeat_count(self): @@ -1201,8 +1201,8 @@ class Dataset: Return: Number, the count of repeat. """ - if self.input: - return self.input[0].get_repeat_count() + if self.children: + return self.children[0].get_repeat_count() return 1 def get_class_indexing(self): @@ -1212,22 +1212,22 @@ class Dataset: Return: Dict, A str-to-int mapping from label name to index. """ - if self.input: - return self.input[0].get_class_indexing() + if self.children: + return self.children[0].get_class_indexing() raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self))) def reset(self): """Reset the dataset for next epoch.""" def is_shuffled(self): - for input_dataset in self.input: + for input_dataset in self.children: if input_dataset.is_shuffled(): return True return False def is_sharded(self): - for input_dataset in self.input: + for input_dataset in self.children: if input_dataset.is_sharded(): return True @@ -1466,8 +1466,8 @@ class BucketBatchByLengthDataset(DatasetOp): self.pad_to_bucket_boundary = pad_to_bucket_boundary self.drop_remainder = drop_remainder - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -1529,8 +1529,8 @@ class BatchDataset(DatasetOp): self.per_batch_map = per_batch_map self.input_columns = input_columns self.pad_info = pad_info - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -1549,7 +1549,7 @@ class BatchDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.input[0].get_dataset_size() + child_size = self.children[0].get_dataset_size() if child_size is not None: if self.drop_remainder: return math.floor(child_size / self.batch_size) @@ -1578,7 +1578,7 @@ class BatchDataset(DatasetOp): if isinstance(dataset, RepeatDataset): return True flag = False - for input_dataset in dataset.input: + for input_dataset in dataset.children: flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) return flag @@ -1593,7 +1593,7 @@ class BatchDataset(DatasetOp): """ if isinstance(dataset, SyncWaitDataset): dataset.update_sync_batch_size(batch_size) - for input_dataset in dataset.input: + for input_dataset in dataset.children: BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) @@ -1699,21 +1699,21 @@ class SyncWaitDataset(DatasetOp): def __init__(self, input_dataset, condition_name, num_batch, callback=None): super().__init__() - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) # set to the default value, waiting for the batch to update it self._condition_name = condition_name if isinstance(num_batch, int) and num_batch <= 0: raise ValueError("num_batch need to be greater than 0.") self._pair = BlockReleasePair(num_batch, callback) - if self._condition_name in self.input[0].get_sync_notifiers(): + if self._condition_name in self.children[0].get_sync_notifiers(): raise RuntimeError("Condition name is already in use") logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging", condition_name) def get_sync_notifiers(self): - return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} + return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} def is_sync(self): return True @@ -1746,7 +1746,7 @@ class SyncWaitDataset(DatasetOp): if isinstance(dataset, BatchDataset): return True flag = False - for input_dataset in dataset.input: + for input_dataset in dataset.children: flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) return flag @@ -1766,9 +1766,9 @@ class ShuffleDataset(DatasetOp): def __init__(self, input_dataset, buffer_size): super().__init__() self.buffer_size = buffer_size - self.input.append(input_dataset) + self.children.append(input_dataset) self.reshuffle_each_epoch = None - input_dataset.output.append(self) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs if self.is_sync(): raise RuntimeError("No shuffle after sync operators") @@ -1864,7 +1864,7 @@ class MapDataset(DatasetOp): def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, num_parallel_workers=None, python_multiprocessing=False): super().__init__(num_parallel_workers) - self.input.append(input_dataset) + self.children.append(input_dataset) if input_columns is not None and not isinstance(input_columns, list): input_columns = [input_columns] self.input_columns = input_columns @@ -1881,7 +1881,7 @@ class MapDataset(DatasetOp): and self.columns_order is None: raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.") - input_dataset.output.append(self) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs self.python_multiprocessing = python_multiprocessing self.process_pool = None @@ -1901,7 +1901,7 @@ class MapDataset(DatasetOp): Return: Number, number of batches. """ - return self.input[0].get_dataset_size() + return self.children[0].get_dataset_size() def __deepcopy__(self, memodict): if id(self) in memodict: @@ -1909,12 +1909,12 @@ class MapDataset(DatasetOp): cls = self.__class__ new_op = cls.__new__(cls) memodict[id(self)] = new_op - new_op.input = copy.deepcopy(self.input, memodict) + new_op.children = copy.deepcopy(self.children, memodict) new_op.input_columns = copy.deepcopy(self.input_columns, memodict) new_op.output_columns = copy.deepcopy(self.output_columns, memodict) new_op.columns_order = copy.deepcopy(self.columns_order, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) - new_op.output = copy.deepcopy(self.output, memodict) + new_op.parent = copy.deepcopy(self.parent, memodict) new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.operations = self.operations @@ -1975,8 +1975,8 @@ class FilterDataset(DatasetOp): def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): super().__init__(num_parallel_workers) self.predicate = lambda *args: bool(predicate(*args)) - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) if input_columns is not None and not isinstance(input_columns, list): input_columns = [input_columns] self.input_columns = input_columns @@ -2012,8 +2012,8 @@ class RepeatDataset(DatasetOp): self.count = -1 else: self.count = count - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -2028,7 +2028,7 @@ class RepeatDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.input[0].get_dataset_size() + child_size = self.children[0].get_dataset_size() if child_size is not None: return child_size return None @@ -2055,8 +2055,8 @@ class SkipDataset(DatasetOp): def __init__(self, input_dataset, count): super().__init__() self.count = count - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -2071,7 +2071,7 @@ class SkipDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.input[0].get_dataset_size() + child_size = self.children[0].get_dataset_size() output_size = 0 if self.count >= 0 and self.count < child_size: output_size = child_size - self.count @@ -2090,8 +2090,8 @@ class TakeDataset(DatasetOp): def __init__(self, input_dataset, count): super().__init__() self.count = count - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -2106,7 +2106,7 @@ class TakeDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.input[0].get_dataset_size() + child_size = self.children[0].get_dataset_size() if child_size < self.count: return child_size return self.count @@ -2130,8 +2130,8 @@ class ZipDataset(DatasetOp): raise TypeError("The parameter %s of zip has type error!" % (dataset)) self.datasets = datasets for data in datasets: - self.input.append(data) - data.output.append(self) + self.children.append(data) + data.parent.append(self) def get_dataset_size(self): """ @@ -2140,7 +2140,7 @@ class ZipDataset(DatasetOp): Return: Number, number of batches. """ - children_sizes = [c.get_dataset_size() for c in self.input] + children_sizes = [c.get_dataset_size() for c in self.children] if all(c is not None for c in children_sizes): return min(children_sizes) return None @@ -2155,7 +2155,7 @@ class ZipDataset(DatasetOp): return None def is_sync(self): - return any([c.is_sync() for c in self.input]) + return any([c.is_sync() for c in self.children]) def get_args(self): args = super().get_args() @@ -2180,8 +2180,8 @@ class ConcatDataset(DatasetOp): raise TypeError("The parameter %s of concat has type error!" % (dataset)) self.datasets = datasets for data in datasets: - self.input.append(data) - data.output.append(self) + self.children.append(data) + data.parent.append(self) def get_dataset_size(self): """ @@ -2190,7 +2190,7 @@ class ConcatDataset(DatasetOp): Return: Number, number of batches. """ - children_sizes = [c.get_dataset_size() for c in self.input] + children_sizes = [c.get_dataset_size() for c in self.children] dataset_size = sum(children_sizes) return dataset_size @@ -2213,8 +2213,8 @@ class RenameDataset(DatasetOp): output_columns = [output_columns] self.input_column_names = input_columns self.output_column_names = output_columns - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -2240,10 +2240,10 @@ class ProjectDataset(DatasetOp): if not isinstance(columns, list): columns = [columns] self.columns = columns - self.input.append(input_dataset) + self.children.append(input_dataset) self.prefetch_size = prefetch_size - input_dataset.output.append(self) + input_dataset.parent.append(self) self._input_indexs = input_dataset.input_indexs def get_args(self): @@ -2267,8 +2267,8 @@ class TransferDataset(DatasetOp): def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None): super().__init__() - self.input.append(input_dataset) - input_dataset.output.append(self) + self.children.append(input_dataset) + input_dataset.parent.append(self) self.queue_name = queue_name self._input_indexs = input_dataset.input_indexs self._device_type = device_type @@ -3170,8 +3170,8 @@ class GeneratorDataset(MappableDataset): cls = self.__class__ new_op = cls.__new__(cls) memodict[id(self)] = new_op - new_op.input = copy.deepcopy(self.input, memodict) - new_op.output = copy.deepcopy(self.output, memodict) + new_op.children = copy.deepcopy(self.children, memodict) + new_op.parent = copy.deepcopy(self.parent, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.column_types = copy.deepcopy(self.column_types, memodict) new_op.column_names = copy.deepcopy(self.column_names, memodict) @@ -4879,14 +4879,14 @@ class BuildVocabDataset(DatasetOp): prefetch_size=None): super().__init__() self.columns = columns - self.input.append(input_dataset) + self.children.append(input_dataset) self.prefetch_size = prefetch_size self.vocab = vocab self.freq_range = freq_range self.top_k = top_k self.special_tokens = special_tokens self.special_first = special_first - input_dataset.output.append(self) + input_dataset.parent.append(self) def get_args(self): args = super().get_args() @@ -4905,11 +4905,11 @@ class BuildVocabDataset(DatasetOp): cls = self.__class__ new_op = cls.__new__(cls) memodict[id(self)] = new_op - new_op.input = copy.deepcopy(self.input, memodict) + new_op.children = copy.deepcopy(self.children, memodict) new_op.columns = copy.deepcopy(self.columns, memodict) new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict) - new_op.output = copy.deepcopy(self.output, memodict) + new_op.parent = copy.deepcopy(self.parent, memodict) new_op.freq_range = copy.deepcopy(self.freq_range, memodict) new_op.top_k = copy.deepcopy(self.top_k, memodict) new_op.vocab = self.vocab diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 89d8165b1e8..4946fb32527 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -38,13 +38,13 @@ def _cleanup(): def alter_tree(node): """Traversing the python Dataset tree/graph to perform some alteration to some specific nodes.""" - if not node.input: + if not node.children: return _alter_node(node) converted_children = [] - for input_op in node.input: + for input_op in node.children: converted_children.append(alter_tree(input_op)) - node.input = converted_children + node.children = converted_children return _alter_node(node) @@ -86,14 +86,14 @@ class Iterator: def __is_tree_node(self, node): """Check if a node is tree node.""" - if not node.input: - if len(node.output) > 1: + if not node.children: + if len(node.parent) > 1: return False - if len(node.output) > 1: + if len(node.parent) > 1: return False - for input_node in node.input: + for input_node in node.children: cls = self.__is_tree_node(input_node) if not cls: return False @@ -174,7 +174,7 @@ class Iterator: op_type = self.__get_dataset_type(node) c_node = self.depipeline.AddNodeToTree(op_type, node.get_args()) - for py_child in node.input: + for py_child in node.children: c_child = self.__convert_node_postorder(py_child) self.depipeline.AddChildToParentNode(c_child, c_node) @@ -184,7 +184,7 @@ class Iterator: """Recursively get batch node in the dataset tree.""" if isinstance(dataset, de.BatchDataset): return - for input_op in dataset.input: + for input_op in dataset.children: self.__batch_node(input_op, level + 1) @staticmethod @@ -194,11 +194,11 @@ class Iterator: ptr = hex(id(dataset)) for _ in range(level): logger.info("\t", end='') - if not dataset.input: + if not dataset.children: logger.info("-%s (%s)", name, ptr) else: logger.info("+%s (%s)", name, ptr) - for input_op in dataset.input: + for input_op in dataset.children: Iterator.__print_local(input_op, level + 1) def print(self): diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index d0399d6bea5..833f660f167 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -182,11 +182,11 @@ def traverse(node): node_repr['shard_id'] = None # Leaf node doesn't have input attribute. - if not node.input: + if not node.children: return node_repr # Recursively traverse the child and assign it to the current node_repr['children']. - for child in node.input: + for child in node.children: node_repr["children"].append(traverse(child)) return node_repr @@ -226,11 +226,11 @@ def construct_pipeline(node): # Instantiate python Dataset object based on the current dictionary element dataset = create_node(node) # Initially it is not connected to any other object. - dataset.input = [] + dataset.children = [] # Construct the children too and add edge between the children and parent. for child in node['children']: - dataset.input.append(construct_pipeline(child)) + dataset.children.append(construct_pipeline(child)) return dataset diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index 0b896e8d319..af5a66e89e7 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -103,7 +103,7 @@ def test_tree_copy(): itr = data1.create_tuple_iterator() assert id(data1) != id(itr.dataset) - assert id(data) != id(itr.dataset.input[0]) + assert id(data) != id(itr.dataset.children[0]) assert id(data1.operations[0]) == id(itr.dataset.operations[0]) itr.release()