Updated concat-zip check and UT.

Rescue 2021-11-17 21:34:07 +00:00
parent 70363899e7
commit da986710f0
2 changed files with 82 additions and 14 deletions


@@ -24,26 +24,40 @@ import mindspore.ops.composite as C
 from mindspore.ops import operations as P
 
 
+def check_concat_zip_dataset(dataset):
+    """
+    Check if dataset is concatenated or zipped.
+    """
+    while dataset:
+        if len(dataset.children) > 1:
+            return True
+        if dataset.children:
+            dataset = dataset.children[0]
+            continue
+        dataset = dataset.children
+    return False
+
+
 def check_map_offload(dataset):
     """
     Check if offload flag is set in data pipeline map ops.
     """
-    offload_ckeck = False
-    dataset_tmp = dataset
-    while dataset_tmp:
-        if hasattr(dataset_tmp, 'offload'):
-            if dataset_tmp.offload is True:
-                offload_ckeck = True
-        if dataset_tmp.children:
-            dataset_tmp = dataset_tmp.children[0]
-            continue
-        dataset_tmp = dataset_tmp.children
+    offload_check = False
+    concat_zip_check = check_concat_zip_dataset(dataset)
+    while dataset:
+        if hasattr(dataset, 'offload'):
+            if dataset.offload is True:
+                offload_check = True
+                break
+        if dataset.children:
+            dataset = dataset.children[0]
+        else:
+            dataset = []
 
-    if offload_ckeck is True:
-        if len(dataset.children) > 1:
-            raise RuntimeError("Offload currently does not support concatenated datasets.")
+    if offload_check and concat_zip_check:
+        raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
 
-    return offload_ckeck
+    return offload_check
 
 
 def apply_offload_iterators(data, offload_model):
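
For readers tracing the pipeline walk, below is a minimal, self-contained sketch of the traversal that check_concat_zip_dataset performs. The Node class is a hypothetical stand-in for a dataset node (only its children list matters here), not a real MindSpore class; the point is that concat/zip nodes are the only pipeline nodes with more than one child, so finding any node with len(children) > 1 on the way to the source flags the pipeline.

# Hypothetical stand-in for a dataset pipeline node; real MindSpore datasets
# expose a similar `children` list linking each op to its input dataset(s).
class Node:
    def __init__(self, children=None):
        self.children = children or []


def has_concat_or_zip(node):
    """Walk toward the source; any node with more than one child is a concat/zip node."""
    while node:
        if len(node.children) > 1:
            return True
        # Follow the single-input chain, or stop once the source (leaf) node is reached.
        node = node.children[0] if node.children else []
    return False


linear_pipeline = Node([Node([Node()])])          # batch -> map -> source
concat_pipeline = Node([Node([Node(), Node()])])  # batch -> concat(source, source)
assert not has_concat_or_zip(linear_pipeline)
assert has_concat_or_zip(concat_pipeline)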


@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest
 
 import mindspore.dataset as ds
 import mindspore.dataset.vision.c_transforms as C
@@ -68,6 +69,59 @@ def test_auto_offload():
         np.testing.assert_array_equal(img_0, img_1)
 
 
+def test_offload_concat_dataset_1():
+    """
+    Feature: Test map offload flag for concatenated dataset.
+    Description: Input is an image dataset; each pipeline is batched before concatenation.
+    Expectation: Should raise RuntimeError.
+    """
+    # Dataset with offload activated.
+    dataset_0 = ds.ImageFolderDataset(DATA_DIR)
+    dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image")
+    dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
+    dataset_0 = dataset_0.batch(8, drop_remainder=True)
+
+    # Dataset with offload not activated.
+    dataset_1 = ds.ImageFolderDataset(DATA_DIR)
+    dataset_1 = dataset_1.map(operations=[C.Decode()], input_columns="image")
+    dataset_1 = dataset_1.map(operations=[C.HWC2CHW()], input_columns="image")
+    dataset_1 = dataset_1.batch(8, drop_remainder=True)
+
+    dataset_concat = dataset_0 + dataset_1
+
+    error_msg = "Offload module currently does not support concatenated or zipped datasets."
+    with pytest.raises(RuntimeError, match=error_msg):
+        for (_, _) in dataset_concat.create_tuple_iterator(num_epochs=1, output_numpy=True):
+            continue
+
+
+def test_offload_concat_dataset_2():
+    """
+    Feature: Test map offload flag for concatenated dataset.
+    Description: Input is an image dataset; the pipelines are concatenated before batching.
+    Expectation: Should raise RuntimeError.
+    """
+    # Dataset with offload activated.
+    dataset_0 = ds.ImageFolderDataset(DATA_DIR)
+    dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image")
+    dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
+
+    # Dataset with offload not activated.
+    dataset_1 = ds.ImageFolderDataset(DATA_DIR)
+    dataset_1 = dataset_1.map(operations=[C.Decode()], input_columns="image")
+    dataset_1 = dataset_1.map(operations=[C.HWC2CHW()], input_columns="image")
+
+    dataset_concat = dataset_0 + dataset_1
+    dataset_concat = dataset_concat.batch(8, drop_remainder=True)
+
+    error_msg = "Offload module currently does not support concatenated or zipped datasets."
+    with pytest.raises(RuntimeError, match=error_msg):
+        for (_, _) in dataset_concat.create_tuple_iterator(num_epochs=1, output_numpy=True):
+            continue
+
+
 if __name__ == "__main__":
     test_offload()
     test_auto_offload()
+    test_offload_concat_dataset_1()
+    test_offload_concat_dataset_2()
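
The new error message also names zipped datasets, but the tests in this commit only exercise concatenation. Below is a hedged sketch of what an analogous zip case could look like; it is not part of the commit. It assumes the same DATA_DIR fixture and transforms as the tests above and the public ds.zip and Dataset.rename APIs; the rename keeps the two pipelines' column names distinct so that zipping itself succeeds and the offload check is what raises.

def test_offload_zip_dataset():
    """
    Feature: Test map offload flag for zipped dataset (illustrative sketch, not in this commit).
    Description: Two image pipelines are zipped; offload=True is set on one map op.
    Expectation: Should raise RuntimeError.
    """
    # Dataset with offload activated.
    dataset_0 = ds.ImageFolderDataset(DATA_DIR)
    dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image")
    dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)

    # Dataset with offload not activated; rename columns so the zip sees no duplicates.
    dataset_1 = ds.ImageFolderDataset(DATA_DIR)
    dataset_1 = dataset_1.map(operations=[C.Decode()], input_columns="image")
    dataset_1 = dataset_1.map(operations=[C.HWC2CHW()], input_columns="image")
    dataset_1 = dataset_1.rename(input_columns=["image", "label"], output_columns=["image_1", "label_1"])

    dataset_zip = ds.zip((dataset_0, dataset_1))
    dataset_zip = dataset_zip.batch(8, drop_remainder=True)

    error_msg = "Offload module currently does not support concatenated or zipped datasets."
    with pytest.raises(RuntimeError, match=error_msg):
        for _ in dataset_zip.create_tuple_iterator(num_epochs=1, output_numpy=True):
            continue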