forked from mindspore-Ecosystem/mindspore
!2404 [Dataset]rename input variable name to children and output variable to parent
Merge pull request !2404 from xulei/input_output
This commit is contained in:
commit
a9c309da4c
|
@ -134,8 +134,8 @@ class Dataset:
|
|||
"""
|
||||
|
||||
def __init__(self, num_parallel_workers=None):
|
||||
self.input = []
|
||||
self.output = []
|
||||
self.children = []
|
||||
self.parent = []
|
||||
self.num_parallel_workers = num_parallel_workers
|
||||
self._device_iter = 0
|
||||
self._input_indexs = ()
|
||||
|
@ -1007,9 +1007,9 @@ class Dataset:
|
|||
dev_id = output_dataset.shard_id
|
||||
return "", dev_id
|
||||
|
||||
if not output_dataset.input:
|
||||
if not output_dataset.children:
|
||||
raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
|
||||
input_dataset = output_dataset.input[0]
|
||||
input_dataset = output_dataset.children[0]
|
||||
return get_distribution(input_dataset)
|
||||
|
||||
distribution_path, device_id = get_distribution(self)
|
||||
|
@ -1130,8 +1130,8 @@ class Dataset:
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.input:
|
||||
return self.input[0].get_dataset_size()
|
||||
if self.children:
|
||||
return self.children[0].get_dataset_size()
|
||||
return None
|
||||
|
||||
def num_classes(self):
|
||||
|
@ -1141,23 +1141,23 @@ class Dataset:
|
|||
Return:
|
||||
Number, number of classes.
|
||||
"""
|
||||
if self.input:
|
||||
return self.input[0].num_classes()
|
||||
if self.children:
|
||||
return self.children[0].num_classes()
|
||||
return None
|
||||
|
||||
def get_sync_notifiers(self):
|
||||
if self.input:
|
||||
return self.input[0].get_sync_notifiers()
|
||||
if self.children:
|
||||
return self.children[0].get_sync_notifiers()
|
||||
return {}
|
||||
|
||||
def disable_sync(self):
|
||||
if self.input:
|
||||
return self.input[0].disable_sync()
|
||||
if self.children:
|
||||
return self.children[0].disable_sync()
|
||||
return {}
|
||||
|
||||
def is_sync(self):
|
||||
if self.input:
|
||||
return self.input[0].is_sync()
|
||||
if self.children:
|
||||
return self.children[0].is_sync()
|
||||
return False
|
||||
|
||||
def sync_update(self, condition_name, num_batch=None, data=None):
|
||||
|
@ -1191,8 +1191,8 @@ class Dataset:
|
|||
Return:
|
||||
Number, the number of data in a batch.
|
||||
"""
|
||||
if self.input:
|
||||
return self.input[0].get_batch_size()
|
||||
if self.children:
|
||||
return self.children[0].get_batch_size()
|
||||
return 1
|
||||
|
||||
def get_repeat_count(self):
|
||||
|
@ -1202,8 +1202,8 @@ class Dataset:
|
|||
Return:
|
||||
Number, the count of repeat.
|
||||
"""
|
||||
if self.input:
|
||||
return self.input[0].get_repeat_count()
|
||||
if self.children:
|
||||
return self.children[0].get_repeat_count()
|
||||
return 1
|
||||
|
||||
def get_class_indexing(self):
|
||||
|
@ -1213,22 +1213,22 @@ class Dataset:
|
|||
Return:
|
||||
Dict, A str-to-int mapping from label name to index.
|
||||
"""
|
||||
if self.input:
|
||||
return self.input[0].get_class_indexing()
|
||||
if self.children:
|
||||
return self.children[0].get_class_indexing()
|
||||
raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
|
||||
|
||||
def reset(self):
|
||||
"""Reset the dataset for next epoch."""
|
||||
|
||||
def is_shuffled(self):
|
||||
for input_dataset in self.input:
|
||||
for input_dataset in self.children:
|
||||
if input_dataset.is_shuffled():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def is_sharded(self):
|
||||
for input_dataset in self.input:
|
||||
for input_dataset in self.children:
|
||||
if input_dataset.is_sharded():
|
||||
return True
|
||||
|
||||
|
@ -1467,8 +1467,8 @@ class BucketBatchByLengthDataset(DatasetOp):
|
|||
self.pad_to_bucket_boundary = pad_to_bucket_boundary
|
||||
self.drop_remainder = drop_remainder
|
||||
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -1530,8 +1530,8 @@ class BatchDataset(DatasetOp):
|
|||
self.per_batch_map = per_batch_map
|
||||
self.input_columns = input_columns
|
||||
self.pad_info = pad_info
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -1550,7 +1550,7 @@ class BatchDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
child_size = self.input[0].get_dataset_size()
|
||||
child_size = self.children[0].get_dataset_size()
|
||||
if child_size is not None:
|
||||
if self.drop_remainder:
|
||||
return math.floor(child_size / self.batch_size)
|
||||
|
@ -1579,7 +1579,7 @@ class BatchDataset(DatasetOp):
|
|||
if isinstance(dataset, RepeatDataset):
|
||||
return True
|
||||
flag = False
|
||||
for input_dataset in dataset.input:
|
||||
for input_dataset in dataset.children:
|
||||
flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
|
||||
return flag
|
||||
|
||||
|
@ -1594,7 +1594,7 @@ class BatchDataset(DatasetOp):
|
|||
"""
|
||||
if isinstance(dataset, SyncWaitDataset):
|
||||
dataset.update_sync_batch_size(batch_size)
|
||||
for input_dataset in dataset.input:
|
||||
for input_dataset in dataset.children:
|
||||
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
|
||||
|
||||
|
||||
|
@ -1700,21 +1700,21 @@ class SyncWaitDataset(DatasetOp):
|
|||
|
||||
def __init__(self, input_dataset, condition_name, num_batch, callback=None):
|
||||
super().__init__()
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
# set to the default value, waiting for the batch to update it
|
||||
self._condition_name = condition_name
|
||||
if isinstance(num_batch, int) and num_batch <= 0:
|
||||
raise ValueError("num_batch need to be greater than 0.")
|
||||
|
||||
self._pair = BlockReleasePair(num_batch, callback)
|
||||
if self._condition_name in self.input[0].get_sync_notifiers():
|
||||
if self._condition_name in self.children[0].get_sync_notifiers():
|
||||
raise RuntimeError("Condition name is already in use")
|
||||
logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
|
||||
condition_name)
|
||||
|
||||
def get_sync_notifiers(self):
|
||||
return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
|
||||
return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
|
||||
|
||||
def is_sync(self):
|
||||
return True
|
||||
|
@ -1747,7 +1747,7 @@ class SyncWaitDataset(DatasetOp):
|
|||
if isinstance(dataset, BatchDataset):
|
||||
return True
|
||||
flag = False
|
||||
for input_dataset in dataset.input:
|
||||
for input_dataset in dataset.children:
|
||||
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
|
||||
return flag
|
||||
|
||||
|
@ -1767,9 +1767,9 @@ class ShuffleDataset(DatasetOp):
|
|||
def __init__(self, input_dataset, buffer_size):
|
||||
super().__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.input.append(input_dataset)
|
||||
self.children.append(input_dataset)
|
||||
self.reshuffle_each_epoch = None
|
||||
input_dataset.output.append(self)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
if self.is_sync():
|
||||
raise RuntimeError("No shuffle after sync operators")
|
||||
|
@ -1865,7 +1865,7 @@ class MapDataset(DatasetOp):
|
|||
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
|
||||
num_parallel_workers=None, python_multiprocessing=False):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.input.append(input_dataset)
|
||||
self.children.append(input_dataset)
|
||||
if input_columns is not None and not isinstance(input_columns, list):
|
||||
input_columns = [input_columns]
|
||||
self.input_columns = input_columns
|
||||
|
@ -1882,7 +1882,7 @@ class MapDataset(DatasetOp):
|
|||
and self.columns_order is None:
|
||||
raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")
|
||||
|
||||
input_dataset.output.append(self)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
self.python_multiprocessing = python_multiprocessing
|
||||
self.process_pool = None
|
||||
|
@ -1902,7 +1902,7 @@ class MapDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
return self.input[0].get_dataset_size()
|
||||
return self.children[0].get_dataset_size()
|
||||
|
||||
def __deepcopy__(self, memodict):
|
||||
if id(self) in memodict:
|
||||
|
@ -1910,12 +1910,12 @@ class MapDataset(DatasetOp):
|
|||
cls = self.__class__
|
||||
new_op = cls.__new__(cls)
|
||||
memodict[id(self)] = new_op
|
||||
new_op.input = copy.deepcopy(self.input, memodict)
|
||||
new_op.children = copy.deepcopy(self.children, memodict)
|
||||
new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
|
||||
new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
|
||||
new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
|
||||
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
||||
new_op.output = copy.deepcopy(self.output, memodict)
|
||||
new_op.parent = copy.deepcopy(self.parent, memodict)
|
||||
new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
|
||||
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
|
||||
new_op.operations = self.operations
|
||||
|
@ -1976,8 +1976,8 @@ class FilterDataset(DatasetOp):
|
|||
def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
|
||||
super().__init__(num_parallel_workers)
|
||||
self.predicate = lambda *args: bool(predicate(*args))
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
if input_columns is not None and not isinstance(input_columns, list):
|
||||
input_columns = [input_columns]
|
||||
self.input_columns = input_columns
|
||||
|
@ -2013,8 +2013,8 @@ class RepeatDataset(DatasetOp):
|
|||
self.count = -1
|
||||
else:
|
||||
self.count = count
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -2029,7 +2029,7 @@ class RepeatDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
child_size = self.input[0].get_dataset_size()
|
||||
child_size = self.children[0].get_dataset_size()
|
||||
if child_size is not None:
|
||||
return child_size
|
||||
return None
|
||||
|
@ -2056,8 +2056,8 @@ class SkipDataset(DatasetOp):
|
|||
def __init__(self, input_dataset, count):
|
||||
super().__init__()
|
||||
self.count = count
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -2072,7 +2072,7 @@ class SkipDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
child_size = self.input[0].get_dataset_size()
|
||||
child_size = self.children[0].get_dataset_size()
|
||||
output_size = 0
|
||||
if self.count >= 0 and self.count < child_size:
|
||||
output_size = child_size - self.count
|
||||
|
@ -2091,8 +2091,8 @@ class TakeDataset(DatasetOp):
|
|||
def __init__(self, input_dataset, count):
|
||||
super().__init__()
|
||||
self.count = count
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -2107,7 +2107,7 @@ class TakeDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
child_size = self.input[0].get_dataset_size()
|
||||
child_size = self.children[0].get_dataset_size()
|
||||
if child_size < self.count:
|
||||
return child_size
|
||||
return self.count
|
||||
|
@ -2131,8 +2131,8 @@ class ZipDataset(DatasetOp):
|
|||
raise TypeError("The parameter %s of zip has type error!" % (dataset))
|
||||
self.datasets = datasets
|
||||
for data in datasets:
|
||||
self.input.append(data)
|
||||
data.output.append(self)
|
||||
self.children.append(data)
|
||||
data.parent.append(self)
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
|
@ -2141,7 +2141,7 @@ class ZipDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
children_sizes = [c.get_dataset_size() for c in self.input]
|
||||
children_sizes = [c.get_dataset_size() for c in self.children]
|
||||
if all(c is not None for c in children_sizes):
|
||||
return min(children_sizes)
|
||||
return None
|
||||
|
@ -2156,7 +2156,7 @@ class ZipDataset(DatasetOp):
|
|||
return None
|
||||
|
||||
def is_sync(self):
|
||||
return any([c.is_sync() for c in self.input])
|
||||
return any([c.is_sync() for c in self.children])
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -2181,8 +2181,8 @@ class ConcatDataset(DatasetOp):
|
|||
raise TypeError("The parameter %s of concat has type error!" % (dataset))
|
||||
self.datasets = datasets
|
||||
for data in datasets:
|
||||
self.input.append(data)
|
||||
data.output.append(self)
|
||||
self.children.append(data)
|
||||
data.parent.append(self)
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
|
@ -2191,7 +2191,7 @@ class ConcatDataset(DatasetOp):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
children_sizes = [c.get_dataset_size() for c in self.input]
|
||||
children_sizes = [c.get_dataset_size() for c in self.children]
|
||||
dataset_size = sum(children_sizes)
|
||||
return dataset_size
|
||||
|
||||
|
@ -2214,8 +2214,8 @@ class RenameDataset(DatasetOp):
|
|||
output_columns = [output_columns]
|
||||
self.input_column_names = input_columns
|
||||
self.output_column_names = output_columns
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -2241,10 +2241,10 @@ class ProjectDataset(DatasetOp):
|
|||
if not isinstance(columns, list):
|
||||
columns = [columns]
|
||||
self.columns = columns
|
||||
self.input.append(input_dataset)
|
||||
self.children.append(input_dataset)
|
||||
self.prefetch_size = prefetch_size
|
||||
|
||||
input_dataset.output.append(self)
|
||||
input_dataset.parent.append(self)
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
|
||||
def get_args(self):
|
||||
|
@ -2268,8 +2268,8 @@ class TransferDataset(DatasetOp):
|
|||
|
||||
def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
|
||||
super().__init__()
|
||||
self.input.append(input_dataset)
|
||||
input_dataset.output.append(self)
|
||||
self.children.append(input_dataset)
|
||||
input_dataset.parent.append(self)
|
||||
self.queue_name = queue_name
|
||||
self._input_indexs = input_dataset.input_indexs
|
||||
self._device_type = device_type
|
||||
|
@ -3171,8 +3171,8 @@ class GeneratorDataset(MappableDataset):
|
|||
cls = self.__class__
|
||||
new_op = cls.__new__(cls)
|
||||
memodict[id(self)] = new_op
|
||||
new_op.input = copy.deepcopy(self.input, memodict)
|
||||
new_op.output = copy.deepcopy(self.output, memodict)
|
||||
new_op.children = copy.deepcopy(self.children, memodict)
|
||||
new_op.parent = copy.deepcopy(self.parent, memodict)
|
||||
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
||||
new_op.column_types = copy.deepcopy(self.column_types, memodict)
|
||||
new_op.column_names = copy.deepcopy(self.column_names, memodict)
|
||||
|
@ -4880,14 +4880,14 @@ class BuildVocabDataset(DatasetOp):
|
|||
prefetch_size=None):
|
||||
super().__init__()
|
||||
self.columns = columns
|
||||
self.input.append(input_dataset)
|
||||
self.children.append(input_dataset)
|
||||
self.prefetch_size = prefetch_size
|
||||
self.vocab = vocab
|
||||
self.freq_range = freq_range
|
||||
self.top_k = top_k
|
||||
self.special_tokens = special_tokens
|
||||
self.special_first = special_first
|
||||
input_dataset.output.append(self)
|
||||
input_dataset.parent.append(self)
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -4906,11 +4906,11 @@ class BuildVocabDataset(DatasetOp):
|
|||
cls = self.__class__
|
||||
new_op = cls.__new__(cls)
|
||||
memodict[id(self)] = new_op
|
||||
new_op.input = copy.deepcopy(self.input, memodict)
|
||||
new_op.children = copy.deepcopy(self.children, memodict)
|
||||
new_op.columns = copy.deepcopy(self.columns, memodict)
|
||||
new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
|
||||
new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
|
||||
new_op.output = copy.deepcopy(self.output, memodict)
|
||||
new_op.parent = copy.deepcopy(self.parent, memodict)
|
||||
new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
|
||||
new_op.top_k = copy.deepcopy(self.top_k, memodict)
|
||||
new_op.vocab = self.vocab
|
||||
|
|
|
@ -38,13 +38,13 @@ def _cleanup():
|
|||
|
||||
def alter_tree(node):
|
||||
"""Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
|
||||
if not node.input:
|
||||
if not node.children:
|
||||
return _alter_node(node)
|
||||
|
||||
converted_children = []
|
||||
for input_op in node.input:
|
||||
for input_op in node.children:
|
||||
converted_children.append(alter_tree(input_op))
|
||||
node.input = converted_children
|
||||
node.children = converted_children
|
||||
return _alter_node(node)
|
||||
|
||||
|
||||
|
@ -86,14 +86,14 @@ class Iterator:
|
|||
|
||||
def __is_tree_node(self, node):
|
||||
"""Check if a node is tree node."""
|
||||
if not node.input:
|
||||
if len(node.output) > 1:
|
||||
if not node.children:
|
||||
if len(node.parent) > 1:
|
||||
return False
|
||||
|
||||
if len(node.output) > 1:
|
||||
if len(node.parent) > 1:
|
||||
return False
|
||||
|
||||
for input_node in node.input:
|
||||
for input_node in node.children:
|
||||
cls = self.__is_tree_node(input_node)
|
||||
if not cls:
|
||||
return False
|
||||
|
@ -174,7 +174,7 @@ class Iterator:
|
|||
op_type = self.__get_dataset_type(node)
|
||||
c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())
|
||||
|
||||
for py_child in node.input:
|
||||
for py_child in node.children:
|
||||
c_child = self.__convert_node_postorder(py_child)
|
||||
self.depipeline.AddChildToParentNode(c_child, c_node)
|
||||
|
||||
|
@ -184,7 +184,7 @@ class Iterator:
|
|||
"""Recursively get batch node in the dataset tree."""
|
||||
if isinstance(dataset, de.BatchDataset):
|
||||
return
|
||||
for input_op in dataset.input:
|
||||
for input_op in dataset.children:
|
||||
self.__batch_node(input_op, level + 1)
|
||||
|
||||
@staticmethod
|
||||
|
@ -194,11 +194,11 @@ class Iterator:
|
|||
ptr = hex(id(dataset))
|
||||
for _ in range(level):
|
||||
logger.info("\t", end='')
|
||||
if not dataset.input:
|
||||
if not dataset.children:
|
||||
logger.info("-%s (%s)", name, ptr)
|
||||
else:
|
||||
logger.info("+%s (%s)", name, ptr)
|
||||
for input_op in dataset.input:
|
||||
for input_op in dataset.children:
|
||||
Iterator.__print_local(input_op, level + 1)
|
||||
|
||||
def print(self):
|
||||
|
|
|
@ -182,11 +182,11 @@ def traverse(node):
|
|||
node_repr['shard_id'] = None
|
||||
|
||||
# Leaf node doesn't have input attribute.
|
||||
if not node.input:
|
||||
if not node.children:
|
||||
return node_repr
|
||||
|
||||
# Recursively traverse the child and assign it to the current node_repr['children'].
|
||||
for child in node.input:
|
||||
for child in node.children:
|
||||
node_repr["children"].append(traverse(child))
|
||||
|
||||
return node_repr
|
||||
|
@ -226,11 +226,11 @@ def construct_pipeline(node):
|
|||
# Instantiate python Dataset object based on the current dictionary element
|
||||
dataset = create_node(node)
|
||||
# Initially it is not connected to any other object.
|
||||
dataset.input = []
|
||||
dataset.children = []
|
||||
|
||||
# Construct the children too and add edge between the children and parent.
|
||||
for child in node['children']:
|
||||
dataset.input.append(construct_pipeline(child))
|
||||
dataset.children.append(construct_pipeline(child))
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ def test_tree_copy():
|
|||
itr = data1.create_tuple_iterator()
|
||||
|
||||
assert id(data1) != id(itr.dataset)
|
||||
assert id(data) != id(itr.dataset.input[0])
|
||||
assert id(data) != id(itr.dataset.children[0])
|
||||
assert id(data1.operations[0]) == id(itr.dataset.operations[0])
|
||||
|
||||
itr.release()
|
||||
|
|
Loading…
Reference in New Issue