forked from mindspore-Ecosystem/mindspore
Updated concat-zip check and UT.
This commit is contained in:
parent
70363899e7
commit
da986710f0
|
@ -24,26 +24,40 @@ import mindspore.ops.composite as C
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def check_concat_zip_dataset(dataset):
|
||||
"""
|
||||
Check if dataset is concatenated or zipped.
|
||||
"""
|
||||
while dataset:
|
||||
if len(dataset.children) > 1:
|
||||
return True
|
||||
if dataset.children:
|
||||
dataset = dataset.children[0]
|
||||
continue
|
||||
dataset = dataset.children
|
||||
return False
|
||||
|
||||
|
||||
def check_map_offload(dataset):
|
||||
"""
|
||||
Check if offload flag is set in data pipeline map ops.
|
||||
"""
|
||||
offload_ckeck = False
|
||||
dataset_tmp = dataset
|
||||
while dataset_tmp:
|
||||
if hasattr(dataset_tmp, 'offload'):
|
||||
if dataset_tmp.offload is True:
|
||||
offload_ckeck = True
|
||||
if dataset_tmp.children:
|
||||
dataset_tmp = dataset_tmp.children[0]
|
||||
continue
|
||||
dataset_tmp = dataset_tmp.children
|
||||
offload_check = False
|
||||
concat_zip_check = check_concat_zip_dataset(dataset)
|
||||
while dataset:
|
||||
if hasattr(dataset, 'offload'):
|
||||
if dataset.offload is True:
|
||||
offload_check = True
|
||||
break
|
||||
if dataset.children:
|
||||
dataset = dataset.children[0]
|
||||
else:
|
||||
dataset = []
|
||||
|
||||
if offload_ckeck is True:
|
||||
if len(dataset.children) > 1:
|
||||
raise RuntimeError("Offload currently does not support concatenated datasets.")
|
||||
if offload_check and concat_zip_check:
|
||||
raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
|
||||
|
||||
return offload_ckeck
|
||||
return offload_check
|
||||
|
||||
|
||||
def apply_offload_iterators(data, offload_model):
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
|
@ -68,6 +69,59 @@ def test_auto_offload():
|
|||
np.testing.assert_array_equal(img_0, img_1)
|
||||
|
||||
|
||||
def test_offload_concat_dataset_1():
|
||||
"""
|
||||
Feature: test map offload flag for concatenated dataset.
|
||||
Description: Input is image dataset.
|
||||
Expectation: Should raise RuntimeError.
|
||||
"""
|
||||
# Dataset with offload activated.
|
||||
dataset_0 = ds.ImageFolderDataset(DATA_DIR)
|
||||
dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image")
|
||||
dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
|
||||
dataset_0 = dataset_0.batch(8, drop_remainder=True)
|
||||
|
||||
# Dataset with offload not activated.
|
||||
dataset_1 = ds.ImageFolderDataset(DATA_DIR)
|
||||
dataset_1 = dataset_1.map(operations=[C.Decode()], input_columns="image")
|
||||
dataset_1 = dataset_1.map(operations=[C.HWC2CHW()], input_columns="image")
|
||||
dataset_1 = dataset_1.batch(8, drop_remainder=True)
|
||||
|
||||
dataset_concat = dataset_0 + dataset_1
|
||||
|
||||
error_msg = "Offload module currently does not support concatenated or zipped datasets."
|
||||
with pytest.raises(RuntimeError, match=error_msg):
|
||||
for (_, _) in dataset_concat.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
continue
|
||||
|
||||
|
||||
def test_offload_concat_dataset_2():
|
||||
"""
|
||||
Feature: test map offload flag for concatenated dataset.
|
||||
Description: Input is image dataset.
|
||||
Expectation: Should raise RuntimeError.
|
||||
"""
|
||||
# Dataset with offload activated.
|
||||
dataset_0 = ds.ImageFolderDataset(DATA_DIR)
|
||||
dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image")
|
||||
dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
|
||||
|
||||
# Dataset with offload not activated.
|
||||
dataset_1 = ds.ImageFolderDataset(DATA_DIR)
|
||||
dataset_1 = dataset_1.map(operations=[C.Decode()], input_columns="image")
|
||||
dataset_1 = dataset_1.map(operations=[C.HWC2CHW()], input_columns="image")
|
||||
|
||||
dataset_concat = dataset_0 + dataset_1
|
||||
dataset_concat = dataset_concat.batch(8, drop_remainder=True)
|
||||
|
||||
error_msg = "Offload module currently does not support concatenated or zipped datasets."
|
||||
with pytest.raises(RuntimeError, match=error_msg):
|
||||
for (_, _) in dataset_concat.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_offload()
|
||||
test_auto_offload()
|
||||
test_offload_concat_dataset_1()
|
||||
test_offload_concat_dataset_2()
|
||||
|
|
Loading…
Reference in New Issue