forked from mindspore-Ecosystem/mindspore
!434 Bug in cleaning dataset iterators
Merge pull request !434 from h.farahat/multi_itr_bug
This commit is contained in:
commit
9e1b5efd1d
|
@ -28,10 +28,10 @@ ITERATORS_LIST = list()
|
|||
|
||||
|
||||
def _cleanup():
|
||||
for itr in ITERATORS_LIST:
|
||||
iter_ref = itr()
|
||||
for itr_ref in ITERATORS_LIST:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
iter_ref.release()
|
||||
itr.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup
|
||||
|
||||
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
|
||||
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
|
||||
|
@ -41,3 +43,41 @@ def test_case_iterator():
|
|||
check(COLUMNS[0:7])
|
||||
check(COLUMNS[7:8])
|
||||
check(COLUMNS[0:2:8])
|
||||
|
||||
|
||||
def test_iterator_weak_ref():
|
||||
ITERATORS_LIST.clear()
|
||||
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
|
||||
itr1 = data.create_tuple_iterator()
|
||||
itr2 = data.create_tuple_iterator()
|
||||
itr3 = data.create_tuple_iterator()
|
||||
|
||||
assert len(ITERATORS_LIST) == 3
|
||||
assert sum(itr() is not None for itr in ITERATORS_LIST) == 3
|
||||
|
||||
del itr1
|
||||
assert len(ITERATORS_LIST) == 3
|
||||
assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
|
||||
|
||||
del itr2
|
||||
assert len(ITERATORS_LIST) == 3
|
||||
assert sum(itr() is not None for itr in ITERATORS_LIST) == 1
|
||||
|
||||
del itr3
|
||||
assert len(ITERATORS_LIST) == 3
|
||||
assert sum(itr() is not None for itr in ITERATORS_LIST) == 0
|
||||
|
||||
itr1 = data.create_tuple_iterator()
|
||||
itr2 = data.create_tuple_iterator()
|
||||
itr3 = data.create_tuple_iterator()
|
||||
|
||||
_cleanup()
|
||||
with pytest.raises(AttributeError) as info:
|
||||
itr2.get_next()
|
||||
assert "object has no attribute 'depipeline'" in str(info.value)
|
||||
|
||||
del itr1
|
||||
assert len(ITERATORS_LIST) == 6
|
||||
assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
|
||||
|
||||
_cleanup()
|
||||
|
|
Loading…
Reference in New Issue