forked from mindspore-Ecosystem/mindspore
fix validation errors, and fix try catch error tests
parent 089623ad19
commit 05b2a57d2a
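This commit does two things across the dataset validators and their tests: argument checks are tightened (for example, num_shards is now validated with check_pos_int32 and the error messages are made consistent), and the affected tests are migrated from try/except blocks to pytest.raises, asserting on str(...) of the captured exception instead of repr(...). Below is a minimal sketch of the test pattern the suite moves to; check_shards is a hypothetical stand-in written only for illustration, not code from this commit.

import pytest


def check_shards(num_shards):
    # Hypothetical validator, standing in for the dataset argument checks.
    if num_shards <= 0:
        raise ValueError("num_shards must be a positive int32 value")


def test_check_shards_rejects_zero():
    # pytest.raises captures the exception; the message is asserted via
    # str(err.value), which is the exception message, rather than repr().
    with pytest.raises(ValueError) as err:
        check_shards(0)
    assert "positive int32" in str(err.value)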
@@ -25,7 +25,7 @@ from mindspore._c_expression import typing
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
     INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
     validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
-    check_columns, check_positive, check_pos_int32
+    check_columns, check_pos_int32

 from . import datasets
 from . import samplers
@@ -319,10 +319,9 @@ def check_generatordataset(method):
             # These two parameters appear together.
             raise ValueError("num_shards and shard_id need to be passed in together")
         if num_shards is not None:
-            type_check(num_shards, (int,), "num_shards")
-            check_positive(num_shards, "num_shards")
+            check_pos_int32(num_shards, "num_shards")
             if shard_id >= num_shards:
-                raise ValueError("shard_id should be less than num_shards")
+                raise ValueError("shard_id should be less than num_shards.")

         sampler = param_dict.get("sampler")
         if sampler is not None:
@@ -417,7 +416,7 @@ def check_bucket_batch_by_length(method):

         all_non_negative = all(item > 0 for item in bucket_boundaries)
         if not all_non_negative:
-            raise ValueError("bucket_boundaries cannot contain any negative numbers.")
+            raise ValueError("bucket_boundaries must only contain positive numbers.")

         for i in range(len(bucket_boundaries) - 1):
             if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
@@ -1044,7 +1043,8 @@ def check_numpyslicesdataset(method):

         data = param_dict.get("data")
         column_names = param_dict.get("column_names")
+        if not data:
+            raise ValueError("Argument data cannot be empty")
         type_check(data, (list, tuple, dict, np.ndarray), "data")
         if isinstance(data, tuple):
             type_check(data[0], (list, np.ndarray), "data[0]")
@@ -62,7 +62,8 @@ def check_from_file(method):
     def new_method(self, *args, **kwargs):
         [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                                 **kwargs)
-        check_unique_list_of_words(special_tokens, "special_tokens")
+        if special_tokens is not None:
+            check_unique_list_of_words(special_tokens, "special_tokens")
         type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
         if vocab_size is not None:
             check_value(vocab_size, (-1, INT32_MAX), "vocab_size")
@@ -45,6 +45,7 @@ def test_bucket_batch_invalid_input():
     bucket_boundaries = [1, 2, 3]
     empty_bucket_boundaries = []
     invalid_bucket_boundaries = ["1", "2", "3"]
+    zero_start_bucket_boundaries = [0, 2, 3]
     negative_bucket_boundaries = [1, 2, -3]
     decreasing_bucket_boundaries = [3, 2, 1]
     non_increasing_bucket_boundaries = [1, 2, 2]
@@ -69,9 +70,13 @@ def test_bucket_batch_invalid_input():
         _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes)
     assert "bucket_boundaries should be a list of int" in str(info.value)

+    with pytest.raises(ValueError) as info:
+        _ = dataset.bucket_batch_by_length(column_names, zero_start_bucket_boundaries, bucket_batch_sizes)
+    assert "bucket_boundaries must only contain positive numbers." in str(info.value)
+
     with pytest.raises(ValueError) as info:
         _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes)
-    assert "bucket_boundaries cannot contain any negative numbers" in str(info.value)
+    assert "bucket_boundaries must only contain positive numbers." in str(info.value)

     with pytest.raises(ValueError) as info:
         _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes)
@@ -108,7 +108,7 @@ def test_concatenate_op_type_mismatch():
     with pytest.raises(RuntimeError) as error_info:
         for _ in data:
             pass
-    assert "Tensor types do not match" in repr(error_info.value)
+    assert "Tensor types do not match" in str(error_info.value)


 def test_concatenate_op_type_mismatch2():
@@ -123,7 +123,7 @@ def test_concatenate_op_type_mismatch2():
     with pytest.raises(RuntimeError) as error_info:
         for _ in data:
             pass
-    assert "Tensor types do not match" in repr(error_info.value)
+    assert "Tensor types do not match" in str(error_info.value)


 def test_concatenate_op_incorrect_dim():
@@ -138,13 +138,13 @@ def test_concatenate_op_incorrect_dim():
     with pytest.raises(RuntimeError) as error_info:
         for _ in data:
             pass
-    assert "Only 1D tensors supported" in repr(error_info.value)
+    assert "Only 1D tensors supported" in str(error_info.value)


 def test_concatenate_op_wrong_axis():
     with pytest.raises(ValueError) as error_info:
         data_trans.Concatenate(2)
-    assert "only 1D concatenation supported." in repr(error_info.value)
+    assert "only 1D concatenation supported." in str(error_info.value)


 def test_concatenate_op_negative_axis():
@@ -167,7 +167,7 @@ def test_concatenate_op_incorrect_input_dim():

     with pytest.raises(ValueError) as error_info:
         data_trans.Concatenate(0, prepend_tensor)
-    assert "can only prepend 1D arrays." in repr(error_info.value)
+    assert "can only prepend 1D arrays." in str(error_info.value)


 if __name__ == "__main__":
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import numpy as np
+import sys
 import pytest
+import numpy as np
+import pandas as pd
 import mindspore.dataset as de
 from mindspore import log as logger
 import mindspore.dataset.transforms.vision.c_transforms as vision
-import pandas as pd


 def test_numpy_slices_list_1():
@@ -173,6 +174,25 @@ def test_numpy_slices_distributed_sampler():
     assert sum([1 for _ in ds]) == 2


+def test_numpy_slices_distributed_shard_limit():
+    logger.info("Test Slicing a 1D list.")
+
+    np_data = [1, 2, 3]
+    num = sys.maxsize
+    with pytest.raises(ValueError) as err:
+        de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False)
+    assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value)
+
+
+def test_numpy_slices_distributed_zero_shard():
+    logger.info("Test Slicing a 1D list.")
+
+    np_data = [1, 2, 3]
+    with pytest.raises(ValueError) as err:
+        de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False)
+    assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value)
+
+
 def test_numpy_slices_sequential_sampler():
     logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")

@@ -210,6 +230,15 @@ def test_numpy_slices_invalid_empty_column_names():
     assert "column_names should not be empty" in str(err.value)


+def test_numpy_slices_invalid_empty_data_column():
+    logger.info("Test incorrect column_names input")
+    np_data = []
+
+    with pytest.raises(ValueError) as err:
+        de.NumpySlicesDataset(np_data, shuffle=False)
+    assert "Argument data cannot be empty" in str(err.value)
+
+
 if __name__ == "__main__":
     test_numpy_slices_list_1()
     test_numpy_slices_list_2()
@@ -223,7 +252,10 @@ if __name__ == "__main__":
     test_numpy_slices_csv_dict()
     test_numpy_slices_num_samplers()
     test_numpy_slices_distributed_sampler()
+    test_numpy_slices_distributed_shard_limit()
+    test_numpy_slices_distributed_zero_shard()
     test_numpy_slices_sequential_sampler()
     test_numpy_slices_invalid_column_names_type()
     test_numpy_slices_invalid_column_names_string()
     test_numpy_slices_invalid_empty_column_names()
+    test_numpy_slices_invalid_empty_data_column()
@@ -82,9 +82,9 @@ def test_fillop_error_handling():
     data = data.map(input_columns=["col"], operations=fill_op)

     with pytest.raises(RuntimeError) as error_info:
-        for data_row in data:
-            print(data_row)
-    assert "Types do not match" in repr(error_info.value)
+        for _ in data:
+            pass
+    assert "Types do not match" in str(error_info.value)


 if __name__ == "__main__":
@@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards():
         num_iter = 0
         for _ in data_set.create_dict_iterator():
             num_iter += 1
-    assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
+    assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)

     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id():
         num_iter = 0
         for _ in data_set.create_dict_iterator():
             num_iter += 1
-    assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info)
+    assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))

@@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard():
         num_iter = 0
         for _ in data_set.create_dict_iterator():
             num_iter += 1
-    assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
+    assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)

     with pytest.raises(Exception) as error_info:
         data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
         num_iter = 0
         for _ in data_set.create_dict_iterator():
             num_iter += 1
-    assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info)
+    assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)

     os.remove(CV_FILE_NAME)
     os.remove("{}.db".format(CV_FILE_NAME))
@@ -39,8 +39,27 @@ def test_on_tokenized_line():
     res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14],
                     [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32)
     for i, d in enumerate(data.create_dict_iterator()):
-        _ = (np.testing.assert_array_equal(d["text"], res[i]), i)
+        np.testing.assert_array_equal(d["text"], res[i])
+
+
+def test_on_tokenized_line_with_no_special_tokens():
+    data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False)
+    jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP)
+    with open(VOCAB_FILE, 'r') as f:
+        for line in f:
+            word = line.split(',')[0]
+            jieba_op.add_word(word)
+
+    data = data.map(input_columns=["text"], operations=jieba_op)
+    vocab = text.Vocab.from_file(VOCAB_FILE, ",")
+    lookup = text.Lookup(vocab, "not")
+    data = data.map(input_columns=["text"], operations=lookup)
+    res = np.array([[8, 0, 9, 0, 10, 0, 13, 0, 11, 0, 12],
+                    [9, 0, 10, 0, 8, 0, 12, 0, 11, 0, 13]], dtype=np.int32)
+    for i, d in enumerate(data.create_dict_iterator()):
+        np.testing.assert_array_equal(d["text"], res[i])


 if __name__ == '__main__':
     test_on_tokenized_line()
+    test_on_tokenized_line_with_no_special_tokens()
@@ -14,7 +14,7 @@
 # ==============================================================================

 import numpy as np
-
+import pytest
 import mindspore.dataset as ds
 from mindspore import log as logger

@@ -163,7 +163,6 @@ def test_sync_exception_01():
     """
     logger.info("test_sync_exception_01")
     shuffle_size = 4
-    batch_size = 10

     dataset = ds.GeneratorDataset(gen, column_names=["input"])

@@ -171,11 +170,9 @@ def test_sync_exception_01():
     dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
     dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

-    try:
-        dataset = dataset.shuffle(shuffle_size)
-    except Exception as e:
-        assert "shuffle" in str(e)
-    dataset = dataset.batch(batch_size)
+    with pytest.raises(RuntimeError) as e:
+        dataset.shuffle(shuffle_size)
+    assert "No shuffle after sync operators" in str(e.value)


 def test_sync_exception_02():
@@ -183,7 +180,6 @@ def test_sync_exception_02():
     Test sync: with duplicated condition name
     """
     logger.info("test_sync_exception_02")
-    batch_size = 6

     dataset = ds.GeneratorDataset(gen, column_names=["input"])

@@ -192,11 +188,9 @@ def test_sync_exception_02():

     dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

-    try:
-        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
-    except Exception as e:
-        assert "name" in str(e)
-    dataset = dataset.batch(batch_size)
+    with pytest.raises(RuntimeError) as e:
+        dataset.sync_wait(num_batch=2, condition_name="every batch")
+    assert "Condition name is already in use" in str(e.value)


 def test_sync_exception_03():
@@ -209,12 +203,9 @@ def test_sync_exception_03():

     aug = Augment(0)
     # try to create dataset with batch_size < 0
-    try:
-        dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
-    except Exception as e:
-        assert "num_batch" in str(e)
-
-    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
+    with pytest.raises(ValueError) as e:
+        dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
+    assert "num_batch need to be greater than 0." in str(e.value)


 def test_sync_exception_04():
@@ -230,14 +221,13 @@ def test_sync_exception_04():
     dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
     dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
     count = 0
-    try:
+    with pytest.raises(RuntimeError) as e:
         for _ in dataset.create_dict_iterator():
             count += 1
             data = {"loss": count}
             # dataset.disable_sync()
             dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
-    except Exception as e:
-        assert "batch" in str(e)
+    assert "Sync_update batch size can only be positive" in str(e.value)


 def test_sync_exception_05():
     """
@@ -251,15 +241,15 @@ def test_sync_exception_05():
     # try to create dataset with batch_size < 0
     dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
     dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
-    try:
+    with pytest.raises(RuntimeError) as e:
         for _ in dataset.create_dict_iterator():
             dataset.disable_sync()
             count += 1
             data = {"loss": count}
             dataset.disable_sync()
             dataset.sync_update(condition_name="every", data=data)
-    except Exception as e:
-        assert "name" in str(e)
+    assert "Condition name not found" in str(e.value)


 if __name__ == "__main__":
     test_simple_sync_wait()
@@ -16,6 +16,7 @@
 Testing UniformAugment in DE
 """
 import numpy as np
+import pytest

 import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.vision.c_transforms as C
@@ -164,14 +165,13 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2):
                      C.RandomRotation(degrees=45),
                      F.Invert()]

-    try:
+    with pytest.raises(TypeError) as e:
         _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)

-    except Exception as e:
-        logger.info("Got an exception in DE: {}".format(str(e)))
-        assert "Argument tensor_op_5 with value" \
-               " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e)
-        assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e)
+    logger.info("Got an exception in DE: {}".format(str(e)))
+    assert "Argument tensor_op_5 with value" \
+           " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e.value)
+    assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e.value)


 def test_cpp_uniform_augment_exception_large_numops(num_ops=6):