From 6c37ea3be0efc25374691c4d1c3566690619afc4 Mon Sep 17 00:00:00 2001 From: nhussain Date: Fri, 19 Jun 2020 09:58:38 -0400 Subject: [PATCH] fix validators fixed random_apply tests fix validators fixed random_apply tests fix engine validation --- mindspore/dataset/core/validator_helpers.py | 342 +++++++ mindspore/dataset/engine/validators.py | 922 ++++++------------ mindspore/dataset/text/transforms.py | 2 +- mindspore/dataset/text/utils.py | 33 +- mindspore/dataset/text/validators.py | 316 ++---- mindspore/dataset/transforms/validators.py | 246 +---- .../dataset/transforms/vision/c_transforms.py | 85 +- .../transforms/vision/py_transforms.py | 66 +- .../dataset/transforms/vision/validators.py | 699 ++++--------- .../dataset/test_bounding_box_augment.py | 8 +- .../dataset/test_bucket_batch_by_length.py | 6 +- .../ut/python/dataset/test_concatenate_op.py | 13 +- tests/ut/python/dataset/test_exceptions.py | 10 +- tests/ut/python/dataset/test_from_dataset.py | 18 +- .../dataset/test_linear_transformation.py | 16 +- .../dataset/test_minddataset_exception.py | 14 +- tests/ut/python/dataset/test_ngram_op.py | 8 +- tests/ut/python/dataset/test_normalizeOp.py | 2 +- tests/ut/python/dataset/test_pad_end_op.py | 4 + tests/ut/python/dataset/test_random_affine.py | 17 +- tests/ut/python/dataset/test_random_color.py | 2 +- .../dataset/test_random_crop_and_resize.py | 8 +- .../test_random_crop_and_resize_with_bbox.py | 4 +- .../python/dataset/test_random_grayscale.py | 2 +- .../dataset/test_random_horizontal_flip.py | 4 +- .../test_random_horizontal_flip_with_bbox.py | 2 +- .../python/dataset/test_random_perspective.py | 4 +- .../dataset/test_random_resize_with_bbox.py | 6 +- .../python/dataset/test_random_sharpness.py | 2 +- .../dataset/test_random_vertical_flip.py | 4 +- .../test_random_vertical_flip_with_bbox.py | 2 +- .../python/dataset/test_resize_with_bbox.py | 2 +- tests/ut/python/dataset/test_shuffle.py | 6 +- tests/ut/python/dataset/test_ten_crop.py | 4 +- 
.../ut/python/dataset/test_uniform_augment.py | 8 +- tests/ut/python/dataset/util.py | 8 +- 36 files changed, 1136 insertions(+), 1759 deletions(-) create mode 100644 mindspore/dataset/core/validator_helpers.py diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py new file mode 100644 index 00000000000..7a93fcf174b --- /dev/null +++ b/mindspore/dataset/core/validator_helpers.py @@ -0,0 +1,342 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +General Validators. 
+""" +import inspect +from multiprocessing import cpu_count +import os +import numpy as np +from ..engine import samplers + +# POS_INT_MIN is used to limit values from starting from 0 +POS_INT_MIN = 1 +UINT8_MAX = 255 +UINT8_MIN = 0 +UINT32_MAX = 4294967295 +UINT32_MIN = 0 +UINT64_MAX = 18446744073709551615 +UINT64_MIN = 0 +INT32_MAX = 2147483647 +INT32_MIN = -2147483648 +INT64_MAX = 9223372036854775807 +INT64_MIN = -9223372036854775808 +FLOAT_MAX_INTEGER = 16777216 +FLOAT_MIN_INTEGER = -16777216 +DOUBLE_MAX_INTEGER = 9007199254740992 +DOUBLE_MIN_INTEGER = -9007199254740992 + +valid_detype = [ + "bool", "int8", "int16", "int32", "int64", "uint8", "uint16", + "uint32", "uint64", "float16", "float32", "float64", "string" +] + + +def pad_arg_name(arg_name): + if arg_name != "": + arg_name = arg_name + " " + return arg_name + + +def check_value(value, valid_range, arg_name=""): + arg_name = pad_arg_name(arg_name) + if value < valid_range[0] or value > valid_range[1]: + raise ValueError( + "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0], + valid_range[1])) + + +def check_range(values, valid_range, arg_name=""): + arg_name = pad_arg_name(arg_name) + if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: + raise ValueError( + "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0], + valid_range[1])) + + +def check_positive(value, arg_name=""): + arg_name = pad_arg_name(arg_name) + if value <= 0: + raise ValueError("Input {0}must be greater than 0.".format(arg_name)) + + +def check_positive_float(value, arg_name=""): + arg_name = pad_arg_name(arg_name) + type_check(value, (float,), arg_name) + check_positive(value, arg_name) + + +def check_2tuple(value, arg_name=""): + if not (isinstance(value, tuple) and len(value) == 2): + raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name)) + + +def check_uint8(value, arg_name=""): + type_check(value, (int,), arg_name) + 
check_value(value, [UINT8_MIN, UINT8_MAX]) + + +def check_uint32(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT32_MIN, UINT32_MAX]) + + +def check_pos_int32(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [POS_INT_MIN, INT32_MAX]) + + +def check_uint64(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT64_MIN, UINT64_MAX]) + + +def check_pos_int64(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT64_MIN, INT64_MAX]) + + +def check_pos_float32(value, arg_name=""): + check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name) + + +def check_pos_float64(value, arg_name=""): + check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name) + + +def check_valid_detype(type_): + if type_ not in valid_detype: + raise ValueError("Unknown column type") + return True + + +def check_columns(columns, name): + type_check(columns, (list, str), name) + if isinstance(columns, list): + if not columns: + raise ValueError("Column names should not be empty") + col_names = ["col_{0}".format(i) for i in range(len(columns))] + type_check_list(columns, (str,), col_names) + + +def parse_user_args(method, *args, **kwargs): + """ + Parse user arguments in a function + + Args: + method (method): a callable function + *args: user passed args + **kwargs: user passed kwargs + + Returns: + user_filled_args (list): values of what the user passed in for the arguments, + ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed. 
+ """ + sig = inspect.signature(method) + if 'self' in sig.parameters or 'cls' in sig.parameters: + ba = sig.bind(method, *args, **kwargs) + ba.apply_defaults() + params = list(sig.parameters.keys())[1:] + else: + ba = sig.bind(*args, **kwargs) + ba.apply_defaults() + params = list(sig.parameters.keys()) + + user_filled_args = [ba.arguments.get(arg_value) for arg_value in params] + return user_filled_args, ba.arguments + + +def type_check_list(args, types, arg_names): + """ + Check the type of each parameter in the list + + Args: + args (list, tuple): a list or tuple of any variable + types (tuple): tuple of all valid types for arg + arg_names (list, tuple of str): the names of args + + Returns: + Exception: when the type is not correct, otherwise nothing + """ + type_check(args, (list, tuple,), arg_names) + if len(args) != len(arg_names): + raise ValueError("List of arguments is not the same length as argument_names.") + for arg, arg_name in zip(args, arg_names): + type_check(arg, types, arg_name) + + +def type_check(arg, types, arg_name): + """ + Check the type of the parameter + + Args: + arg : any variable + types (tuple): tuple of all valid types for arg + arg_name (str): the name of arg + + Returns: + Exception: when the type is not correct, otherwise nothing + """ + # handle special case of booleans being a subclass of ints + print_value = '\"\"' if repr(arg) == repr('') else arg + + if int in types and bool not in types: + if isinstance(arg, bool): + raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types)) + if not isinstance(arg, types): + raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types)) + + +def check_filename(path): + """ + check the filename in the path + + Args: + path (str): the path + + Returns: + Exception: when error + """ + if not isinstance(path, str): + raise TypeError("path: {} is not string".format(path)) + filename = os.path.basename(path) + 
+ # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`', + # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>', + # '*', '(', '%', ')', '-', '=', '{', '?', '$' + forbidden_symbols = set(r'\/:*?"<>|`&\';') + + if set(filename) & forbidden_symbols: + raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'") + + if filename.startswith(' ') or filename.endswith(' '): + raise ValueError("filename should not start/end with space") + + return True + + +def check_dir(dataset_dir): + if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK): + raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir)) + + +def check_file(dataset_file): + check_filename(dataset_file) + if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK): + raise ValueError("The file {} does not exist or permission denied!".format(dataset_file)) + + +def check_sampler_shuffle_shard_options(param_dict): + """ + Check for valid shuffle, sampler, num_shards, and shard_id inputs. 
+ Args: + param_dict (dict): param_dict + + Returns: + Exception: ValueError or RuntimeError if error + """ + shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') + num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') + + type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler") + + if sampler is not None: + if shuffle is not None: + raise RuntimeError("sampler and shuffle cannot be specified at the same time.") + + if num_shards is not None: + check_pos_int32(num_shards) + if shard_id is None: + raise RuntimeError("num_shards is specified and currently requires shard_id as well.") + check_value(shard_id, [0, num_shards - 1], "shard_id") + + if num_shards is None and shard_id is not None: + raise RuntimeError("shard_id is specified but num_shards is not.") + + +def check_padding_options(param_dict): + """ + Check for valid padded_sample and num_padded of padded samples + + Args: + param_dict (dict): param_dict + + Returns: + Exception: ValueError or RuntimeError if error + """ + + columns_list = param_dict.get('columns_list') + block_reader = param_dict.get('block_reader') + padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') + if padded_sample is not None: + if num_padded is None: + raise RuntimeError("padded_sample is specified and requires num_padded as well.") + if num_padded < 0: + raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded)) + if columns_list is None: + raise RuntimeError("padded_sample is specified and requires columns_list as well.") + for column in columns_list: + if column not in padded_sample: + raise ValueError("padded_sample cannot match columns_list.") + if block_reader: + raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") + + if padded_sample is None and num_padded is not None: + raise RuntimeError("num_padded is specified but padded_sample is not.") + + +def 
check_num_parallel_workers(value): + type_check(value, (int,), "num_parallel_workers") + if value < 1 or value > cpu_count(): + raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count())) + + +def check_num_samples(value): + type_check(value, (int,), "num_samples") + check_value(value, [0, INT32_MAX], "num_samples") + + +def validate_dataset_param_value(param_list, param_dict, param_type): + for param_name in param_list: + if param_dict.get(param_name) is not None: + if param_name == 'num_parallel_workers': + check_num_parallel_workers(param_dict.get(param_name)) + if param_name == 'num_samples': + check_num_samples(param_dict.get(param_name)) + else: + type_check(param_dict.get(param_name), (param_type,), param_name) + + +def check_gnn_list_or_ndarray(param, param_name): + """ + Check if the input parameter is list or numpy.ndarray. + + Args: + param (list, nd.ndarray): param + param_name (str): param_name + + Returns: + Exception: TypeError if error + """ + + type_check(param, (list, np.ndarray), param_name) + if isinstance(param, list): + param_names = ["param_{0}".format(i) for i in range(len(param))] + type_check_list(param, (int,), param_names) + + elif isinstance(param, np.ndarray): + if not param.dtype == np.int32: + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( + param_name, param.dtype)) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index d980245c045..7edf381b2c6 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -9,210 +9,50 @@ # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and
+# See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Built-in validators.
+"""
+Built-in validators.
 """
 import inspect as ins
 import os
 from functools import wraps
-from multiprocessing import cpu_count

 import numpy as np

 from mindspore._c_expression import typing

+from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
+    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
+    validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
+    check_columns, check_positive
 from . import datasets
 from . import samplers

-INT32_MAX = 2147483647
-valid_detype = [
-    "bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
-    "uint32", "uint64", "float16", "float32", "float64", "string"
-]
-
-
-def check_valid_detype(type_):
-    if type_ not in valid_detype:
-        raise ValueError("Unknown column type")
-    return True
-
-
-def check_filename(path):
-    """
-    check the filename in the path
-
-    Args:
-        path (str): the path
-
-    Returns:
-        Exception: when error
-    """
-    if not isinstance(path, str):
-        raise TypeError("path: {} is not string".format(path))
-    filename = os.path.basename(path)
-
-    # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
-    # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
-    # '*', '(', '%', ')', '-', '=', '{', '?', '$'
-    forbidden_symbols = set(r'\/:*?"<>|`&\';')
-
-    if set(filename) & forbidden_symbols:
-        raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
-
-    if filename.startswith(' ') or filename.endswith(' '):
-        raise ValueError("filename should not start/end with space")
-
-    return True
-
-
-def make_param_dict(method, args, kwargs):
-    """Return a dictionary of the method's args
and kwargs.""" - sig = ins.signature(method) - params = sig.parameters - keys = list(params.keys()) - param_dict = dict() - try: - for name, value in enumerate(args): - param_dict[keys[name]] = value - except IndexError: - raise TypeError("{0}() expected {1} arguments, but {2} were given".format( - method.__name__, len(keys) - 1, len(args) - 1)) - - param_dict.update(zip(params.keys(), args)) - param_dict.update(kwargs) - - for name, value in params.items(): - if name not in param_dict: - param_dict[name] = value.default - return param_dict - - -def check_type(param, param_name, valid_type): - if (not isinstance(param, valid_type)) or (valid_type == int and isinstance(param, bool)): - raise TypeError("Wrong input type for {0}, should be {1}, got {2}".format(param_name, valid_type, type(param))) - - -def check_param_type(param_list, param_dict, param_type): - for param_name in param_list: - if param_dict.get(param_name) is not None: - if param_name == 'num_parallel_workers': - check_num_parallel_workers(param_dict.get(param_name)) - if param_name == 'num_samples': - check_num_samples(param_dict.get(param_name)) - else: - check_type(param_dict.get(param_name), param_name, param_type) - - -def check_positive_int32(param, param_name): - check_interval_closed(param, param_name, [1, INT32_MAX]) - - -def check_interval_closed(param, param_name, valid_range): - if param < valid_range[0] or param > valid_range[1]: - raise ValueError("The value of {0} exceeds the closed interval range {1}.".format(param_name, valid_range)) - - -def check_num_parallel_workers(value): - check_type(value, 'num_parallel_workers', int) - if value < 1 or value > cpu_count(): - raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count())) - - -def check_num_samples(value): - check_type(value, 'num_samples', int) - if value < 0: - raise ValueError("num_samples cannot be less than 0!") - - -def check_dataset_dir(dataset_dir): - if not os.path.isdir(dataset_dir) 
or not os.access(dataset_dir, os.R_OK): - raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir)) - - -def check_dataset_file(dataset_file): - check_filename(dataset_file) - if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK): - raise ValueError("The file {} does not exist or permission denied!".format(dataset_file)) - - -def check_sampler_shuffle_shard_options(param_dict): - """check for valid shuffle, sampler, num_shards, and shard_id inputs.""" - shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') - num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') - - if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): - raise TypeError("sampler is not a valid Sampler type.") - - if sampler is not None: - if shuffle is not None: - raise RuntimeError("sampler and shuffle cannot be specified at the same time.") - - if num_shards is not None: - raise RuntimeError("sampler and sharding cannot be specified at the same time.") - - if num_shards is not None: - check_positive_int32(num_shards, "num_shards") - if shard_id is None: - raise RuntimeError("num_shards is specified and currently requires shard_id as well.") - if shard_id < 0 or shard_id >= num_shards: - raise ValueError("shard_id is invalid, shard_id={}".format(shard_id)) - - if num_shards is None and shard_id is not None: - raise RuntimeError("shard_id is specified but num_shards is not.") - - -def check_padding_options(param_dict): - """ check for valid padded_sample and num_padded of padded samples""" - columns_list = param_dict.get('columns_list') - block_reader = param_dict.get('block_reader') - padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') - if padded_sample is not None: - if num_padded is None: - raise RuntimeError("padded_sample is specified and requires num_padded as well.") - if num_padded < 0: - raise 
ValueError("num_padded is invalid, num_padded={}.".format(num_padded)) - if columns_list is None: - raise RuntimeError("padded_sample is specified and requires columns_list as well.") - for column in columns_list: - if column not in padded_sample: - raise ValueError("padded_sample cannot match columns_list.") - if block_reader: - raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") - - if padded_sample is None and num_padded is not None: - raise RuntimeError("num_padded is specified but padded_sample is not.") def check_imagefolderdatasetv2(method): """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_list = ['extensions'] nreq_param_dict = ['class_indexing'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_dict, param_dict, dict) + check_dir(dataset_dir) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -221,25 +61,21 @@ def check_mnist_cifar_dataset(method): """A wrapper 
that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) + check_dir(dataset_dir) - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -248,31 +84,25 @@ def check_manifestdataset(method): """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_str = ['usage'] nreq_param_dict = ['class_indexing'] - # check dataset_file; required argument dataset_file = param_dict.get('dataset_file') - if dataset_file is None: - raise ValueError("dataset_file is not provided.") - check_dataset_file(dataset_file) + check_file(dataset_file) - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_str, param_dict, str) - - check_param_type(nreq_param_dict, param_dict, dict) + 
validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_str, param_dict, str) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -281,29 +111,24 @@ def check_tfrecorddataset(method): """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_list = ['columns_list'] nreq_param_bool = ['shard_equal_rows'] - # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') - if dataset_files is None: - raise ValueError("dataset_files is not provided.") if not isinstance(dataset_files, (str, list)): raise TypeError("dataset_files should be of type str or a list of strings.") - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -312,32 +137,22 @@ def check_vocdataset(method): """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = 
parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_dict = ['class_indexing'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - # check task; required argument - task = param_dict.get('task') - if task is None: - raise ValueError("task is not provided.") - if not isinstance(task, str): - raise TypeError("task is not str type.") - # check mode; required argument - mode = param_dict.get('mode') - if mode is None: - raise ValueError("mode is not provided.") - if not isinstance(mode, str): - raise TypeError("mode is not str type.") + check_dir(dataset_dir) + + task = param_dict.get('task') + type_check(task, (str,), "task") + + mode = param_dict.get('mode') + type_check(mode, (str,), "mode") - imagesets_file = "" if task == "Segmentation": imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt") if param_dict.get('class_indexing') is not None: @@ -347,17 +162,14 @@ def check_vocdataset(method): else: raise ValueError("Invalid task : " + task) - check_dataset_file(imagesets_file) - - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_dict, param_dict, dict) + check_file(imagesets_file) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -366,44 +178,34 @@ def check_cocodataset(method): """A wrapper that wrap a parameter checker to the original Dataset(CocoDataset).""" @wraps(method) - def new_method(*args, 
**kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) + check_dir(dataset_dir) - # check annotation_file; required argument annotation_file = param_dict.get('annotation_file') - if annotation_file is None: - raise ValueError("annotation_file is not provided.") - check_dataset_file(annotation_file) + check_file(annotation_file) - # check task; required argument task = param_dict.get('task') - if task is None: - raise ValueError("task is not provided.") - if not isinstance(task, str): - raise TypeError("task is not str type.") + type_check(task, (str,), "task") if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}: raise ValueError("Invalid task type") - check_param_type(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_int, param_dict, int) - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) sampler = param_dict.get('sampler') if sampler is not None and isinstance(sampler, samplers.PKSampler): raise ValueError("CocoDataset doesn't support PKSampler") check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -412,27 +214,22 @@ def check_celebadataset(method): """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = 
['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_list = ['extensions'] nreq_param_str = ['dataset_type'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - check_param_type(nreq_param_int, param_dict, int) + check_dir(dataset_dir) - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_str, param_dict, str) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_str, param_dict, str) dataset_type = param_dict.get('dataset_type') if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'): @@ -444,7 +241,7 @@ def check_celebadataset(method): if sampler is not None and isinstance(sampler, samplers.PKSampler): raise ValueError("CelebADataset does not support PKSampler.") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -453,36 +250,30 @@ def check_minddataset(method): """A wrapper that wrap a parameter checker to the original Dataset(MindDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'] nreq_param_list = ['columns_list'] nreq_param_bool = ['block_reader'] nreq_param_dict = ['padded_sample'] - # check dataset_file; required argument dataset_file = param_dict.get('dataset_file') - if dataset_file is None: - raise ValueError("dataset_file is not provided.") if 
isinstance(dataset_file, list): for f in dataset_file: - check_dataset_file(f) + check_file(f) else: - check_dataset_file(dataset_file) + check_file(dataset_file) - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_dict, param_dict, dict) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) check_padding_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -491,20 +282,17 @@ def check_generatordataset(method): """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) - # check generator_function; required argument source = param_dict.get('source') - if source is None: - raise ValueError("source is not provided.") + if not callable(source): try: iter(source) except TypeError: raise TypeError("source should be callable, iterable or random accessible") - # check column_names or schema; required argument column_names = param_dict.get('column_names') if column_names is not None: check_columns(column_names, "column_names") @@ -518,11 +306,11 @@ def check_generatordataset(method): # check optional argument nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"] - check_param_type(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_int, param_dict, int) nreq_param_list = ["column_types"] - check_param_type(nreq_param_list, param_dict, list) + 
validate_dataset_param_value(nreq_param_list, param_dict, list) nreq_param_bool = ["shuffle"] - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) num_shards = param_dict.get("num_shards") shard_id = param_dict.get("shard_id") @@ -530,7 +318,8 @@ def check_generatordataset(method): # These two parameters appear together. raise ValueError("num_shards and shard_id need to be passed in together") if num_shards is not None: - check_positive_int32(num_shards, "num_shards") + type_check(num_shards, (int,), "num_shards") + check_positive(num_shards, "num_shards") if shard_id >= num_shards: raise ValueError("shard_id should be less than num_shards") @@ -551,67 +340,46 @@ def check_generatordataset(method): if num_shards is not None and not hasattr(source, "__getitem__"): raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method -def check_batch_size(batch_size): - if not (isinstance(batch_size, int) or (callable(batch_size))): - raise TypeError("batch_size should either be an int or a callable.") - if callable(batch_size): - sig = ins.signature(batch_size) - if len(sig.parameters) != 1: - raise ValueError("batch_size callable should take one parameter (BatchInfo).") - - -def check_count(count): - check_type(count, 'count', int) - if (count <= 0 and count != -1) or count > INT32_MAX: - raise ValueError("count should be either -1 or positive integer.") - - -def check_columns(columns, name): - if isinstance(columns, list): - for column in columns: - if not isinstance(column, str): - raise TypeError("Each column in {0} should be of type str. 
Got {1}.".format(name, type(column))) - elif not isinstance(columns, str): - raise TypeError("{} should be either a list of strings or a single string.".format(name)) - - def check_pad_info(key, val): """check the key and value pair of pad_info in batch""" - check_type(key, "key in pad_info", str) + type_check(key, (str,), "key in pad_info") + if val is not None: assert len(val) == 2, "value of pad_info should be a tuple of size 2" - check_type(val, "value in pad_info", tuple) + type_check(val, (tuple,), "value in pad_info") + if val[0] is not None: - check_type(val[0], "pad_shape", list) + type_check(val[0], (list,), "pad_shape") + for dim in val[0]: if dim is not None: - check_type(dim, "dim in pad_shape", int) + type_check(dim, (int,), "dim in pad_shape") assert dim > 0, "pad shape should be positive integers" if val[1] is not None: - check_type(val[1], "pad_value", (int, float, str, bytes)) + type_check(val[1], (int, float, str, bytes), "pad_value") def check_bucket_batch_by_length(method): """check the input arguments of bucket_batch_by_length.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info, + pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs) nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] - check_param_type(nreq_param_list, param_dict, list) + + type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list) nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder'] - check_param_type(nbool_param_list, param_dict, bool) + type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list) # check column_names: must be list of string. 
- column_names = param_dict.get("column_names") - if not column_names: raise ValueError("column_names cannot be empty") @@ -619,13 +387,10 @@ def check_bucket_batch_by_length(method): if not all_string: raise TypeError("column_names should be a list of str.") - element_length_function = param_dict.get("element_length_function") if element_length_function is None and len(column_names) != 1: raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") # check bucket_boundaries: must be list of int, positive and strictly increasing - bucket_boundaries = param_dict.get('bucket_boundaries') - if not bucket_boundaries: raise ValueError("bucket_boundaries cannot be empty.") @@ -633,7 +398,7 @@ def check_bucket_batch_by_length(method): if not all_int: raise TypeError("bucket_boundaries should be a list of int.") - all_non_negative = all(item >= 0 for item in bucket_boundaries) + all_non_negative = all(item > 0 for item in bucket_boundaries) if not all_non_negative: raise ValueError("bucket_boundaries cannot contain any negative numbers.") @@ -642,7 +407,6 @@ def check_bucket_batch_by_length(method): raise ValueError("bucket_boundaries should be strictly increasing.") # check bucket_batch_sizes: must be list of int and positive - bucket_batch_sizes = param_dict.get('bucket_batch_sizes') if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") @@ -654,12 +418,13 @@ def check_bucket_batch_by_length(method): if not all_non_negative: raise ValueError("bucket_batch_sizes should be a list of positive numbers.") - if param_dict.get('pad_info') is not None: - check_type(param_dict["pad_info"], "pad_info", dict) - for k, v in param_dict.get('pad_info').items(): + if pad_info is not None: + type_check(pad_info, (dict,), "pad_info") + + for k, v in pad_info.items(): check_pad_info(k, v) - return method(*args, **kwargs) + return method(self, *args, 
**kwargs) return new_method @@ -668,37 +433,33 @@ def check_batch(method): """check the input arguments of batch.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [batch_size, drop_remainder, num_parallel_workers, per_batch_map, + input_columns, pad_info], param_dict = parse_user_args(method, *args, **kwargs) - nreq_param_int = ['num_parallel_workers'] - nreq_param_bool = ['drop_remainder'] - nreq_param_columns = ['input_columns'] + if not (isinstance(batch_size, int) or (callable(batch_size))): + raise TypeError("batch_size should either be an int or a callable.") - # check batch_size; required argument - batch_size = param_dict.get("batch_size") - if batch_size is None: - raise ValueError("batch_size is not provided.") - check_batch_size(batch_size) + if callable(batch_size): + sig = ins.signature(batch_size) + if len(sig.parameters) != 1: + raise ValueError("batch_size callable should take one parameter (BatchInfo).") - check_param_type(nreq_param_int, param_dict, int) + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + type_check(drop_remainder, (bool,), "drop_remainder") - check_param_type(nreq_param_bool, param_dict, bool) - - if (param_dict.get('pad_info') is not None) and (param_dict.get('per_batch_map') is not None): + if (pad_info is not None) and (per_batch_map is not None): raise ValueError("pad_info and per_batch_map can't both be set") - if param_dict.get('pad_info') is not None: - check_type(param_dict["pad_info"], "pad_info", dict) + if pad_info is not None: + type_check(param_dict["pad_info"], (dict,), "pad_info") for k, v in param_dict.get('pad_info').items(): check_pad_info(k, v) - for param_name in nreq_param_columns: - param = param_dict.get(param_name) - if param is not None: - check_columns(param, param_name) + if input_columns is not None: + check_columns(input_columns, "input_columns") - per_batch_map, 
input_columns = param_dict.get('per_batch_map'), param_dict.get('input_columns') if (per_batch_map is None) != (input_columns is None): # These two parameters appear together. raise ValueError("per_batch_map and input_columns need to be passed in together.") @@ -709,43 +470,38 @@ def check_batch(method): if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): raise ValueError("the signature of per_batch_map should match with input columns") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method + def check_sync_wait(method): """check the input arguments of sync_wait.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs) - nreq_param_str = ['condition_name'] - nreq_param_int = ['step_size'] + type_check(condition_name, (str,), "condition_name") + type_check(num_batch, (int,), "num_batch") - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_str, param_dict, str) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method + def check_shuffle(method): """check the input arguments of shuffle.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [buffer_size], _ = parse_user_args(method, *args, **kwargs) - # check buffer_size; required argument - buffer_size = param_dict.get("buffer_size") - if buffer_size is None: - raise ValueError("buffer_size is not provided.") - check_type(buffer_size, 'buffer_size', int) - check_interval_closed(buffer_size, 'buffer_size', [2, INT32_MAX]) + type_check(buffer_size, (int,), "buffer_size") - return method(*args, **kwargs) + check_value(buffer_size, [2, INT32_MAX], "buffer_size") + + return method(self, *args, **kwargs) return new_method @@ -754,23 +510,23 
@@ def check_map(method): """check the input arguments of map.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing], _ = \ + parse_user_args(method, *args, **kwargs) - nreq_param_list = ['columns_order'] - nreq_param_int = ['num_parallel_workers'] nreq_param_columns = ['input_columns', 'output_columns'] - nreq_param_bool = ['python_multiprocessing'] - check_param_type(nreq_param_list, param_dict, list) - check_param_type(nreq_param_int, param_dict, int) - check_param_type(nreq_param_bool, param_dict, bool) - for param_name in nreq_param_columns: - param = param_dict.get(param_name) + if columns_order is not None: + type_check(columns_order, (list,), "columns_order") + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + type_check(python_multiprocessing, (bool,), "python_multiprocessing") + + for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): if param is not None: check_columns(param, param_name) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -779,19 +535,20 @@ def check_filter(method): """"check the input arguments of filter.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - predicate = param_dict.get("predicate") + def new_method(self, *args, **kwargs): + [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) if not callable(predicate): raise TypeError("Predicate should be a python function or a callable python object.") - nreq_param_int = ['num_parallel_workers'] - check_param_type(nreq_param_int, param_dict, int) - param_name = "input_columns" - param = param_dict.get(param_name) - if param is not None: - check_columns(param, param_name) - return method(*args, 
**kwargs) + check_num_parallel_workers(num_parallel_workers) + + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + + if input_columns is not None: + check_columns(input_columns, "input_columns") + + return method(self, *args, **kwargs) return new_method @@ -800,14 +557,13 @@ def check_repeat(method): """check the input arguments of repeat.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [count], _ = parse_user_args(method, *args, **kwargs) - count = param_dict.get('count') - if count is not None: - check_count(count) - - return method(*args, **kwargs) + type_check(count, (int, type(None)), "repeat") + if isinstance(count, int): + check_value(count, (-1, INT32_MAX), "count") + return method(self, *args, **kwargs) return new_method @@ -816,15 +572,13 @@ def check_skip(method): """check the input arguments of skip.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [count], _ = parse_user_args(method, *args, **kwargs) - count = param_dict.get('count') - check_type(count, 'count', int) - if count < 0: - raise ValueError("Skip count must be positive integer or 0.") + type_check(count, (int,), "count") + check_value(count, (-1, INT32_MAX), "count") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -833,13 +587,13 @@ def check_take(method): """check the input arguments of take.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [count], _ = parse_user_args(method, *args, **kwargs) + type_check(count, (int,), "count") + if (count <= 0 and count != -1) or count > INT32_MAX: + raise ValueError("count should be either -1 or positive integer.") - count = param_dict.get('count') - check_count(count) - - 
return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -849,13 +603,8 @@ def check_zip(method): @wraps(method) def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check datasets; required argument - ds = param_dict.get("datasets") - if ds is None: - raise ValueError("datasets is not provided.") - check_type(ds, 'datasets', tuple) + [ds], _ = parse_user_args(method, *args, **kwargs) + type_check(ds, (tuple,), "datasets") return method(*args, **kwargs) @@ -866,18 +615,11 @@ def check_zip_dataset(method): """check the input arguments of zip method in `Dataset`.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [ds], _ = parse_user_args(method, *args, **kwargs) + type_check(ds, (tuple, datasets.Dataset), "datasets") - # check datasets; required argument - ds = param_dict.get("datasets") - if ds is None: - raise ValueError("datasets is not provided.") - - if not isinstance(ds, (tuple, datasets.Dataset)): - raise TypeError("datasets is not tuple or of type Dataset.") - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -886,18 +628,13 @@ def check_concat(method): """check the input arguments of concat method in `Dataset`.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check datasets; required argument - ds = param_dict.get("datasets") - if ds is None: - raise ValueError("datasets is not provided.") - - if not isinstance(ds, (list, datasets.Dataset)): - raise TypeError("datasets is not list or of type Dataset.") - - return method(*args, **kwargs) + def new_method(self, *args, **kwargs): + [ds], _ = parse_user_args(method, *args, **kwargs) + type_check(ds, (list, datasets.Dataset), "datasets") + if isinstance(ds, list): + dataset_names = ["dataset[{0}]".format(i) for i in range(len(ds)) if 
isinstance(ds, list)] + type_check_list(ds, (datasets.Dataset,), dataset_names) + return method(self, *args, **kwargs) return new_method @@ -906,26 +643,23 @@ def check_rename(method): """check the input arguments of rename.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + values, _ = parse_user_args(method, *args, **kwargs) req_param_columns = ['input_columns', 'output_columns'] - # check req_param_list; required arguments - for param_name in req_param_columns: - param = param_dict.get(param_name) - if param is None: - raise ValueError("{} is not provided.".format(param_name)) + for param_name, param in zip(req_param_columns, values): check_columns(param, param_name) input_size, output_size = 1, 1 - if isinstance(param_dict.get(req_param_columns[0]), list): - input_size = len(param_dict.get(req_param_columns[0])) - if isinstance(param_dict.get(req_param_columns[1]), list): - output_size = len(param_dict.get(req_param_columns[1])) + input_columns, output_columns = values + if isinstance(input_columns, list): + input_size = len(input_columns) + if isinstance(output_columns, list): + output_size = len(output_columns) if input_size != output_size: raise ValueError("Number of column in input_columns and output_columns is not equal.") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -934,56 +668,39 @@ def check_project(method): """check the input arguments of project.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check columns; required argument - columns = param_dict.get("columns") - if columns is None: - raise ValueError("columns is not provided.") + def new_method(self, *args, **kwargs): + [columns], _ = parse_user_args(method, *args, **kwargs) check_columns(columns, 'columns') - return method(*args, **kwargs) + return method(self, *args, **kwargs) return 
new_method -def check_shape(shape, name): - if isinstance(shape, list): - for element in shape: - if not isinstance(element, int): - raise TypeError( - "Each element in {0} should be of type int. Got {1}.".format(name, type(element))) - else: - raise TypeError("Expected int list.") - - def check_add_column(method): """check the input arguments of add_column.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs) - # check name; required argument - name = param_dict.get("name") - if not isinstance(name, str) or not name: + type_check(name, (str,), "name") + + if not name: raise TypeError("Expected non-empty string.") - # check type; required argument - de_type = param_dict.get("de_type") if de_type is not None: if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type): raise TypeError("Unknown column type.") else: raise TypeError("Expected non-empty string.") - # check shape - shape = param_dict.get("shape") if shape is not None: - check_shape(shape, "shape") + type_check(shape, (list,), "shape") + shape_names = ["shape[{0}]".format(i) for i in range(len(shape))] + type_check_list(shape, (int,), shape_names) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -992,17 +709,13 @@ def check_cluedataset(method): """A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] - # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') - if dataset_files is None: - raise ValueError("dataset_files is not provided.") - if not 
isinstance(dataset_files, (str, list)): - raise TypeError("dataset_files should be of type str or a list of strings.") + type_check(dataset_files, (str, list), "dataset files") # check task task_param = param_dict.get('task') @@ -1014,11 +727,10 @@ def check_cluedataset(method): if usage_param not in ['train', 'test', 'eval']: raise ValueError("usage should be train, test or eval") - check_param_type(nreq_param_int, param_dict, int) - + validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1027,23 +739,17 @@ def check_textfiledataset(method): """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] - # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') - if dataset_files is None: - raise ValueError("dataset_files is not provided.") - if not isinstance(dataset_files, (str, list)): - raise TypeError("dataset_files should be of type str or a list of strings.") - - check_param_type(nreq_param_int, param_dict, int) - + type_check(dataset_files, (str, list), "dataset files") + validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1052,19 +758,16 @@ def check_split(method): """check the input arguments of split.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [sizes, randomize], _ = parse_user_args(method, *args, **kwargs) - 
nreq_param_list = ['sizes'] - nreq_param_bool = ['randomize'] - check_param_type(nreq_param_list, param_dict, list) - check_param_type(nreq_param_bool, param_dict, bool) + type_check(sizes, (list,), "sizes") + type_check(randomize, (bool,), "randomize") # check sizes: must be list of float or list of int - sizes = param_dict.get('sizes') - if not sizes: raise ValueError("sizes cannot be empty.") + all_int = all(isinstance(item, int) for item in sizes) all_float = all(isinstance(item, float) for item in sizes) @@ -1085,7 +788,7 @@ def check_split(method): if not abs(sum(sizes) - 1) < epsilon: raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1094,52 +797,26 @@ def check_gnn_graphdata(method): """check the input arguments of graphdata.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [dataset_file, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) + check_file(dataset_file) - # check dataset_file; required argument - dataset_file = param_dict.get('dataset_file') - if dataset_file is None: - raise ValueError("dataset_file is not provided.") - check_dataset_file(dataset_file) - - nreq_param_int = ['num_parallel_workers'] - - check_param_type(nreq_param_int, param_dict, int) - - return method(*args, **kwargs) + if num_parallel_workers is not None: + type_check(num_parallel_workers, (int,), "num_parallel_workers") + return method(self, *args, **kwargs) return new_method -def check_gnn_list_or_ndarray(param, param_name): - """Check if the input parameter is list or numpy.ndarray.""" - - if isinstance(param, list): - for m in param: - if not isinstance(m, int): - raise TypeError( - "Each member in {0} should be of type int. 
Got {1}.".format(param_name, type(m))) - elif isinstance(param, np.ndarray): - if not param.dtype == np.int32: - raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( - param_name, param.dtype)) - else: - raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( - param_name, type(param))) - - def check_gnn_get_all_nodes(method): """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_type], _ = parse_user_args(method, *args, **kwargs) + type_check(node_type, (int,), "node_type") - # check node_type; required argument - check_type(param_dict.get("node_type"), 'node_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1148,13 +825,11 @@ def check_gnn_get_all_edges(method): """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [edge_type], _ = parse_user_args(method, *args, **kwargs) + type_check(edge_type, (int,), "edge_type") - # check node_type; required argument - check_type(param_dict.get("edge_type"), 'edge_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1163,13 +838,11 @@ def check_gnn_get_nodes_from_edges(method): """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [edge_list], _ = parse_user_args(method, *args, **kwargs) + check_gnn_list_or_ndarray(edge_list, "edge_list") - # check edge_list; required argument - 
check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list') - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1178,16 +851,13 @@ def check_gnn_get_all_neighbors(method): """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(node_list, 'node_list') + type_check(neighbour_type, (int,), "neighbour_type") - # check neighbor_type; required argument - check_type(param_dict.get("neighbor_type"), 'neighbor_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1196,21 +866,16 @@ def check_gnn_get_sampled_neighbors(method): """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(node_list, 'node_list') - # check neighbor_nums; required argument - neighbor_nums = param_dict.get("neighbor_nums") check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') if not neighbor_nums or len(neighbor_nums) > 6: raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( 'neighbor_nums', len(neighbor_nums))) - # check neighbor_types; required argument - neighbor_types = param_dict.get("neighbor_types") check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') 
if not neighbor_types or len(neighbor_types) > 6: raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( @@ -1220,7 +885,7 @@ def check_gnn_get_sampled_neighbors(method): raise ValueError( "The number of members of neighbor_nums and neighbor_types is inconsistent") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1229,20 +894,14 @@ def check_gnn_get_neg_sampled_neighbors(method): """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(node_list, 'node_list') + type_check(neg_neighbor_num, (int,), "neg_neighbor_num") + type_check(neg_neighbor_type, (int,), "neg_neighbor_type") - # check neg_neighbor_num; required argument - check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int) - - # check neg_neighbor_type; required argument - check_type(param_dict.get("neg_neighbor_type"), - 'neg_neighbor_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1251,20 +910,16 @@ def check_gnn_random_walk(method): """A wrapper that wrap a parameter checker to the GNN `random_walk` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = parse_user_args(method, *args, + **kwargs) + check_gnn_list_or_ndarray(target_nodes, 'target_nodes') + check_gnn_list_or_ndarray(meta_path, 'meta_path') + type_check(step_home_param, (float,), 
"step_home_param") + type_check(step_away_param, (float,), "step_away_param") + type_check(default_node, (int,), "default_node") - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes') - - # check meta_path; required argument - check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') - - check_type(param_dict.get("step_home_param"), 'step_home_param', float) - check_type(param_dict.get("step_away_param"), 'step_away_param', float) - check_type(param_dict.get("default_node"), 'default_node', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1272,8 +927,7 @@ def check_gnn_random_walk(method): def check_aligned_list(param, param_name, member_type): """Check whether the structure of each member of the list is the same.""" - if not isinstance(param, list): - raise TypeError("Parameter {0} is not a list".format(param_name)) + type_check(param, (list,), "param") if not param: raise TypeError( "Parameter {0} or its members are empty".format(param_name)) @@ -1282,6 +936,7 @@ def check_aligned_list(param, param_name, member_type): for member in param: if isinstance(member, list): check_aligned_list(member, param_name, member_type) + if member_have_list not in (None, True): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) @@ -1291,9 +946,7 @@ def check_aligned_list(param, param_name, member_type): member_have_list = True list_len = len(member) else: - if not isinstance(member, member_type): - raise TypeError("Each member in {0} should be of type int. 
Got {1}.".format( - param_name, type(member))) + type_check(member, (member_type,), param_name) if member_have_list not in (None, False): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) @@ -1304,26 +957,20 @@ def check_gnn_get_node_feature(method): """A wrapper that wrap a parameter checker to the GNN `get_node_feature` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - node_list = param_dict.get("node_list") + type_check(node_list, (list, np.ndarray), "node_list") if isinstance(node_list, list): check_aligned_list(node_list, 'node_list', int) elif isinstance(node_list, np.ndarray): if not node_list.dtype == np.int32: raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( node_list, node_list.dtype)) - else: - raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( - 'node_list', type(node_list))) - # check feature_types; required argument - check_gnn_list_or_ndarray(param_dict.get( - "feature_types"), 'feature_types') + check_gnn_list_or_ndarray(feature_types, 'feature_types') - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1332,22 +979,17 @@ def check_numpyslicesdataset(method): """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) - # check data; required argument - data = param_dict.get('data') - if not isinstance(data, (list, tuple, dict, np.ndarray)): - raise TypeError("Unsupported data type: {}, only support some common python data type, " - "like 
list, tuple, dict, and numpy array.".format(type(data))) - if isinstance(data, tuple) and not isinstance(data[0], (list, np.ndarray)): - raise TypeError("Unsupported data type: when input is tuple, only support some common python " - "data type, like tuple of lists and tuple of numpy arrays.") - if not data: - raise ValueError("Input data is empty.") + data = param_dict.get("data") + column_names = param_dict.get("column_names") + + type_check(data, (list, tuple, dict, np.ndarray), "data") + if isinstance(data, tuple): + type_check(data[0], (list, np.ndarray), "data[0]") # check column_names - column_names = param_dict.get('column_names') if column_names is not None: check_columns(column_names, "column_names") @@ -1368,6 +1010,6 @@ def check_numpyslicesdataset(method): raise ValueError("Num of input column names is {0}, but required is {1} as data is list." .format(column_num, 1)) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 8b0d47df253..f829e4ba737 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -98,7 +98,7 @@ class Ngram(cde.NgramOp): """ @check_ngram - def __init__(self, n, left_pad=None, right_pad=None, separator=None): + def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "): super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0], r_pad_token=right_pad[0], separator=separator) diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py index 7347a4b8543..ef1d0e6fc5f 100644 --- a/mindspore/dataset/text/utils.py +++ b/mindspore/dataset/text/utils.py @@ -28,6 +28,7 @@ __all__ = [ "Vocab", "to_str", "to_bytes" ] + class Vocab(cde.Vocab): """ Vocab object that is used to lookup a word. 
@@ -38,7 +39,7 @@ class Vocab(cde.Vocab): @classmethod @check_from_dataset def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, - special_first=None): + special_first=True): """ Build a vocab from a dataset. @@ -62,13 +63,21 @@ class Vocab(cde.Vocab): special_tokens(list, optional): a list of strings, each one is a special token. for example special_tokens=["",""] (default=None, no special tokens will be added). special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens - is specified and special_first is set to None, special_tokens will be prepended (default=None). + is specified and special_first is set to True, special_tokens will be prepended (default=True). Returns: Vocab, Vocab object built from dataset. """ vocab = Vocab() + if columns is None: + columns = [] + if not isinstance(columns, list): + columns = [columns] + if freq_range is None: + freq_range = (None, None) + if special_tokens is None: + special_tokens = [] root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) for d in root.create_dict_iterator(): if d is not None: @@ -77,7 +86,7 @@ class Vocab(cde.Vocab): @classmethod @check_from_list - def from_list(cls, word_list, special_tokens=None, special_first=None): + def from_list(cls, word_list, special_tokens=None, special_first=True): """ Build a vocab object from a list of word. @@ -86,29 +95,33 @@ class Vocab(cde.Vocab): special_tokens(list, optional): a list of strings, each one is a special token. for example special_tokens=["",""] (default=None, no special tokens will be added). special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens - is specified and special_first is set to None, special_tokens will be prepended (default=None). + is specified and special_first is set to True, special_tokens will be prepended (default=True). 
""" - + if special_tokens is None: + special_tokens = [] return super().from_list(word_list, special_tokens, special_first) @classmethod @check_from_file - def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None): + def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True): """ Build a vocab object from a list of word. Args: file_path (str): path to the file which contains the vocab list. delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be - the word (default=None). + the word (default=""). vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken). special_tokens (list, optional): a list of strings, each one is a special token. for example special_tokens=["",""] (default=None, no special tokens will be added). special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, - If special_tokens is specified and special_first is set to None, - special_tokens will be prepended (default=None). + If special_tokens is specified and special_first is set to True, + special_tokens will be prepended (default=True). 
""" - + if vocab_size is None: + vocab_size = -1 + if special_tokens is None: + special_tokens = [] return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first) @classmethod diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index afab8665cde..39a0c4e6320 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -17,23 +17,22 @@ validators for text ops """ from functools import wraps - -import mindspore._c_dataengine as cde import mindspore.common.dtype as mstype +import mindspore._c_dataengine as cde from mindspore._c_expression import typing -from ..transforms.validators import check_uint32, check_pos_int64 + +from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \ + INT32_MAX, check_value def check_unique_list_of_words(words, arg_name): """Check that words is a list and each element is a str without any duplication""" - if not isinstance(words, list): - raise ValueError(arg_name + " needs to be a list of words of type string.") + type_check(words, (list,), arg_name) words_set = set() for word in words: - if not isinstance(word, str): - raise ValueError("each word in " + arg_name + " needs to be type str.") + type_check(word, (str,), arg_name) if word in words_set: raise ValueError(arg_name + " contains duplicate word: " + word + ".") words_set.add(word) @@ -45,21 +44,14 @@ def check_lookup(method): @wraps(method) def new_method(self, *args, **kwargs): - vocab, unknown = (list(args) + 2 * [None])[:2] - if "vocab" in kwargs: - vocab = kwargs.get("vocab") - if "unknown" in kwargs: - unknown = kwargs.get("unknown") + [vocab, unknown], _ = parse_user_args(method, *args, **kwargs) + if unknown is not None: - if not (isinstance(unknown, int) and unknown >= 0): - raise ValueError("unknown needs to be a non-negative integer.") + type_check(unknown, (int,), "unknown") + check_positive(unknown) + 
type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.") - if not isinstance(vocab, cde.Vocab): - raise ValueError("vocab is not an instance of cde.Vocab.") - - kwargs["vocab"] = vocab - kwargs["unknown"] = unknown - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -69,50 +61,15 @@ def check_from_file(method): @wraps(method) def new_method(self, *args, **kwargs): - file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5] - if "file_path" in kwargs: - file_path = kwargs.get("file_path") - if "delimiter" in kwargs: - delimiter = kwargs.get("delimiter") - if "vocab_size" in kwargs: - vocab_size = kwargs.get("vocab_size") - if "special_tokens" in kwargs: - special_tokens = kwargs.get("special_tokens") - if "special_first" in kwargs: - special_first = kwargs.get("special_first") - - if not isinstance(file_path, str): - raise ValueError("file_path needs to be str.") - - if delimiter is not None: - if not isinstance(delimiter, str): - raise ValueError("delimiter needs to be str.") - else: - delimiter = "" - if vocab_size is not None: - if not (isinstance(vocab_size, int) and vocab_size > 0): - raise ValueError("vocab size needs to be a positive integer.") - else: - vocab_size = -1 - - if special_first is None: - special_first = True - - if not isinstance(special_first, bool): - raise ValueError("special_first needs to be a boolean value") - - if special_tokens is None: - special_tokens = [] - + [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, + **kwargs) check_unique_list_of_words(special_tokens, "special_tokens") + type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) + if vocab_size is not None: + check_value(vocab_size, (-1, INT32_MAX), "vocab_size") + type_check(special_first, (bool,), special_first) - kwargs["file_path"] = file_path - kwargs["delimiter"] = delimiter - kwargs["vocab_size"] = 
vocab_size - kwargs["special_tokens"] = special_tokens - kwargs["special_first"] = special_first - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -122,33 +79,20 @@ def check_from_list(method): @wraps(method) def new_method(self, *args, **kwargs): - word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3] - if "word_list" in kwargs: - word_list = kwargs.get("word_list") - if "special_tokens" in kwargs: - special_tokens = kwargs.get("special_tokens") - if "special_first" in kwargs: - special_first = kwargs.get("special_first") - if special_tokens is None: - special_tokens = [] + [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs) + word_set = check_unique_list_of_words(word_list, "word_list") - token_set = check_unique_list_of_words(special_tokens, "special_tokens") + if special_tokens is not None: + token_set = check_unique_list_of_words(special_tokens, "special_tokens") - intersect = word_set.intersection(token_set) + intersect = word_set.intersection(token_set) - if intersect != set(): - raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") + if intersect != set(): + raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") - if special_first is None: - special_first = True + type_check(special_first, (bool,), "special_first") - if not isinstance(special_first, bool): - raise ValueError("special_first needs to be a boolean value.") - - kwargs["word_list"] = word_list - kwargs["special_tokens"] = special_tokens - kwargs["special_first"] = special_first - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -158,18 +102,15 @@ def check_from_dict(method): @wraps(method) def new_method(self, *args, **kwargs): - word_dict, = (list(args) + [None])[:1] - if "word_dict" in kwargs: - word_dict = kwargs.get("word_dict") - if not isinstance(word_dict, dict): - 
raise ValueError("word_dict needs to be a list of word,id pairs.") + [word_dict], _ = parse_user_args(method, *args, **kwargs) + + type_check(word_dict, (dict,), "word_dict") + for word, word_id in word_dict.items(): - if not isinstance(word, str): - raise ValueError("Each word in word_dict needs to be type string.") - if not (isinstance(word_id, int) and word_id >= 0): - raise ValueError("Each word id needs to be positive integer.") - kwargs["word_dict"] = word_dict - return method(self, **kwargs) + type_check(word, (str,), "word") + type_check(word_id, (int,), "word_id") + check_value(word_id, (-1, INT32_MAX), "word_id") + return method(self, *args, **kwargs) return new_method @@ -179,23 +120,8 @@ def check_jieba_init(method): @wraps(method) def new_method(self, *args, **kwargs): - hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] - - if "hmm_path" in kwargs: - hmm_path = kwargs.get("hmm_path") - if "mp_path" in kwargs: - mp_path = kwargs.get("mp_path") - if hmm_path is None: - raise ValueError( - "The dict of HMMSegment in cppjieba is not provided.") - kwargs["hmm_path"] = hmm_path - if mp_path is None: - raise ValueError( - "The dict of MPSegment in cppjieba is not provided.") - kwargs["mp_path"] = mp_path - if model is not None: - kwargs["model"] = model - return method(self, **kwargs) + parse_user_args(method, *args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -205,19 +131,12 @@ def check_jieba_add_word(method): @wraps(method) def new_method(self, *args, **kwargs): - word, freq = (list(args) + 2 * [None])[:2] - - if "word" in kwargs: - word = kwargs.get("word") - if "freq" in kwargs: - freq = kwargs.get("freq") + [word, freq], _ = parse_user_args(method, *args, **kwargs) if word is None: raise ValueError("word is not provided.") - kwargs["word"] = word if freq is not None: check_uint32(freq) - kwargs["freq"] = freq - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -227,13 +146,8 @@ def 
check_jieba_add_dict(method): @wraps(method) def new_method(self, *args, **kwargs): - user_dict = (list(args) + [None])[0] - if "user_dict" in kwargs: - user_dict = kwargs.get("user_dict") - if user_dict is None: - raise ValueError("user_dict is not provided.") - kwargs["user_dict"] = user_dict - return method(self, **kwargs) + parse_user_args(method, *args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -244,69 +158,39 @@ def check_from_dataset(method): @wraps(method) def new_method(self, *args, **kwargs): - dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6] - if "dataset" in kwargs: - dataset = kwargs.get("dataset") - if "columns" in kwargs: - columns = kwargs.get("columns") - if "freq_range" in kwargs: - freq_range = kwargs.get("freq_range") - if "top_k" in kwargs: - top_k = kwargs.get("top_k") - if "special_tokens" in kwargs: - special_tokens = kwargs.get("special_tokens") - if "special_first" in kwargs: - special_first = kwargs.get("special_first") + [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args, + **kwargs) + if columns is not None: + if not isinstance(columns, list): + columns = [columns] + col_names = ["col_{0}".format(i) for i in range(len(columns))] + type_check_list(columns, (str,), col_names) - if columns is None: - columns = [] + if freq_range is not None: + type_check(freq_range, (tuple,), "freq_range") - if not isinstance(columns, list): - columns = [columns] + if len(freq_range) != 2: + raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.") - for column in columns: - if not isinstance(column, str): - raise ValueError("columns need to be a list of strings.") + for num in freq_range: + if num is not None and (not isinstance(num, int)): + raise ValueError( + "freq_range needs to be either None or a tuple of 2 integers or an int and a None.") - if freq_range is None: - freq_range = (None, None) + if 
isinstance(freq_range[0], int) and isinstance(freq_range[1], int): + if freq_range[0] > freq_range[1] or freq_range[0] < 0: + raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") - if not isinstance(freq_range, tuple) or len(freq_range) != 2: - raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") + type_check(top_k, (int, type(None)), "top_k") - for num in freq_range: - if num is not None and (not isinstance(num, int)): - raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") + if isinstance(top_k, int): + check_value(top_k, (0, INT32_MAX), "top_k") + type_check(special_first, (bool,), "special_first") - if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): - if freq_range[0] > freq_range[1] or freq_range[0] < 0: - raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") + if special_tokens is not None: + check_unique_list_of_words(special_tokens, "special_tokens") - if top_k is not None and (not isinstance(top_k, int)): - raise ValueError("top_k needs to be a positive integer.") - - if isinstance(top_k, int) and top_k <= 0: - raise ValueError("top_k needs to be a positive integer.") - - if special_first is None: - special_first = True - - if special_tokens is None: - special_tokens = [] - - if not isinstance(special_first, bool): - raise ValueError("special_first needs to be a boolean value.") - - check_unique_list_of_words(special_tokens, "special_tokens") - - kwargs["dataset"] = dataset - kwargs["columns"] = columns - kwargs["freq_range"] = freq_range - kwargs["top_k"] = top_k - kwargs["special_tokens"] = special_tokens - kwargs["special_first"] = special_first - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -316,15 +200,7 @@ def check_ngram(method): @wraps(method) def new_method(self, *args, **kwargs): - n, left_pad, right_pad, separator = 
(list(args) + 4 * [None])[:4] - if "n" in kwargs: - n = kwargs.get("n") - if "left_pad" in kwargs: - left_pad = kwargs.get("left_pad") - if "right_pad" in kwargs: - right_pad = kwargs.get("right_pad") - if "separator" in kwargs: - separator = kwargs.get("separator") + [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs) if isinstance(n, int): n = [n] @@ -332,15 +208,9 @@ def check_ngram(method): if not (isinstance(n, list) and n != []): raise ValueError("n needs to be a non-empty list of positive integers.") - for gram in n: - if not (isinstance(gram, int) and gram > 0): - raise ValueError("n in ngram needs to be a positive number.") - - if left_pad is None: - left_pad = ("", 0) - - if right_pad is None: - right_pad = ("", 0) + for i, gram in enumerate(n): + type_check(gram, (int,), "gram[{0}]".format(i)) + check_value(gram, (0, INT32_MAX), "gram_{}".format(i)) if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( left_pad[1], int)): @@ -353,11 +223,7 @@ def check_ngram(method): if not (left_pad[1] >= 0 and right_pad[1] >= 0): raise ValueError("padding width need to be positive numbers.") - if separator is None: - separator = " " - - if not isinstance(separator, str): - raise ValueError("separator needs to be a string.") + type_check(separator, (str,), "separator") kwargs["n"] = n kwargs["left_pad"] = left_pad @@ -374,16 +240,8 @@ def check_pair_truncate(method): @wraps(method) def new_method(self, *args, **kwargs): - max_length = (list(args) + [None])[0] - if "max_length" in kwargs: - max_length = kwargs.get("max_length") - if max_length is None: - raise ValueError("max_length is not provided.") - - check_pos_int64(max_length) - kwargs["max_length"] = max_length - - return method(self, **kwargs) + parse_user_args(method, *args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -393,22 +251,13 @@ def check_to_number(method): @wraps(method) def new_method(self, 
*args, **kwargs): - data_type = (list(args) + [None])[0] - if "data_type" in kwargs: - data_type = kwargs.get("data_type") - - if data_type is None: - raise ValueError("data_type is a mandatory parameter but was not provided.") - - if not isinstance(data_type, typing.Type): - raise TypeError("data_type is not a MindSpore data type.") + [data_type], _ = parse_user_args(method, *args, **kwargs) + type_check(data_type, (typing.Type,), "data_type") if data_type not in mstype.number_type: raise TypeError("data_type is not numeric data type.") - kwargs["data_type"] = data_type - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -418,18 +267,11 @@ def check_python_tokenizer(method): @wraps(method) def new_method(self, *args, **kwargs): - tokenizer = (list(args) + [None])[0] - if "tokenizer" in kwargs: - tokenizer = kwargs.get("tokenizer") - - if tokenizer is None: - raise ValueError("tokenizer is a mandatory parameter.") + [tokenizer], _ = parse_user_args(method, *args, **kwargs) if not callable(tokenizer): raise TypeError("tokenizer is not a callable python function") - kwargs["tokenizer"] = tokenizer - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index 6b5760e0c5a..9fe0fa5f106 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -18,6 +18,7 @@ from functools import wraps import numpy as np from mindspore._c_expression import typing +from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive # POS_INT_MIN is used to limit values from starting from 0 POS_INT_MIN = 1 @@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992 DOUBLE_MIN_INTEGER = -9007199254740992 -def check_type(value, valid_type): - if not isinstance(value, valid_type): - raise ValueError("Wrong input type") - - -def 
check_value(value, valid_range): - if value < valid_range[0] or value > valid_range[1]: - raise ValueError("Input is not within the required range") - - -def check_range(values, valid_range): - if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: - raise ValueError("Input range is not valid") - - -def check_positive(value): - if value <= 0: - raise ValueError("Input must greater than 0") - - -def check_positive_float(value, valid_max=None): - if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max): - raise ValueError("Input need to be a valid positive float.") - - -def check_bool(value): - if not isinstance(value, bool): - raise ValueError("Value needs to be a boolean.") - - -def check_2tuple(value): - if not (isinstance(value, tuple) and len(value) == 2): - raise ValueError("Value needs to be a 2-tuple.") - - -def check_list(value): - if not isinstance(value, list): - raise ValueError("The input needs to be a list.") - - -def check_uint8(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT8_MIN, UINT8_MAX]) - - -def check_uint32(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT32_MIN, UINT32_MAX]) - - -def check_pos_int32(value): - """Checks for int values starting from 1""" - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [POS_INT_MIN, INT32_MAX]) - - -def check_uint64(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT64_MIN, UINT64_MAX]) - - -def check_pos_int64(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT64_MIN, INT64_MAX]) - - -def check_pos_float32(value): - check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER]) - - -def check_pos_float64(value): - check_value(value, 
[UINT64_MIN, DOUBLE_MAX_INTEGER]) - - -def check_one_hot_op(method): - """Wrapper method to check the parameters of one hot op.""" +def check_fill_value(method): + """Wrapper method to check the parameters of fill_value.""" @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - num_classes, smoothing_rate = args - if "num_classes" in kwargs: - num_classes = kwargs.get("num_classes") - if "smoothing_rate" in kwargs: - smoothing_rate = kwargs.get("smoothing_rate") + [fill_value], _ = parse_user_args(method, *args, **kwargs) + type_check(fill_value, (str, float, bool, int, bytes), "fill_value") + + return method(self, *args, **kwargs) + + return new_method + + +def check_one_hot_op(method): + """Wrapper method to check the parameters of one_hot_op.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs) + + type_check(num_classes, (int,), "num_classes") + check_positive(num_classes) - if num_classes is None: - raise ValueError("num_classes") - check_pos_int32(num_classes) - kwargs["num_classes"] = num_classes if smoothing_rate is not None: - check_value(smoothing_rate, [0., 1.]) - kwargs["smoothing_rate"] = smoothing_rate + check_value(smoothing_rate, [0., 1.], "smoothing_rate") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -146,35 +74,12 @@ def check_num_classes(method): @wraps(method) def new_method(self, *args, **kwargs): - num_classes = (list(args) + [None])[0] - if "num_classes" in kwargs: - num_classes = kwargs.get("num_classes") - if num_classes is None: - raise ValueError("num_classes is not provided.") + [num_classes], _ = parse_user_args(method, *args, **kwargs) - check_pos_int32(num_classes) - kwargs["num_classes"] = num_classes + type_check(num_classes, (int,), "num_classes") + check_positive(num_classes) - return method(self, **kwargs) - - return new_method - - -def 
check_fill_value(method): - """Wrapper method to check the parameters of fill value.""" - - @wraps(method) - def new_method(self, *args, **kwargs): - fill_value = (list(args) + [None])[0] - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - if fill_value is None: - raise ValueError("fill_value is not provided.") - if not isinstance(fill_value, (str, float, bool, int, bytes)): - raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int") - kwargs["fill_value"] = fill_value - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -184,17 +89,11 @@ def check_de_type(method): @wraps(method) def new_method(self, *args, **kwargs): - data_type = (list(args) + [None])[0] - if "data_type" in kwargs: - data_type = kwargs.get("data_type") + [data_type], _ = parse_user_args(method, *args, **kwargs) - if data_type is None: - raise ValueError("data_type is not provided.") - if not isinstance(data_type, typing.Type): - raise TypeError("data_type is not a MindSpore data type.") - kwargs["data_type"] = data_type + type_check(data_type, (typing.Type,), "data_type") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -204,13 +103,11 @@ def check_slice_op(method): @wraps(method) def new_method(self, *args): - for i, arg in enumerate(args): - if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): - raise TypeError("Indexing of dim " + str(i) + "is not of valid type") + for _, arg in enumerate(args): + type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg") if isinstance(arg, list): for a in arg: - if not isinstance(a, int): - raise TypeError("Index " + a + " is not an int") + type_check(a, (int,), "a") return method(self, *args) return new_method @@ -221,36 +118,14 @@ def check_mask_op(method): @wraps(method) def new_method(self, *args, **kwargs): - operator, constant, dtype = (list(args) + 3 * [None])[:3] - if 
"operator" in kwargs: - operator = kwargs.get("operator") - if "constant" in kwargs: - constant = kwargs.get("constant") - if "dtype" in kwargs: - dtype = kwargs.get("dtype") - - if operator is None: - raise ValueError("operator is not provided.") - - if constant is None: - raise ValueError("constant is not provided.") + [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs) from .c_transforms import Relational - if not isinstance(operator, Relational): - raise TypeError("operator is not a Relational operator enum.") + type_check(operator, (Relational,), "operator") + type_check(constant, (str, float, bool, int, bytes), "constant") + type_check(dtype, (typing.Type,), "dtype") - if not isinstance(constant, (str, float, bool, int, bytes)): - raise TypeError("constant must be either a primitive python str, float, bool, bytes or int") - - if dtype is not None: - if not isinstance(dtype, typing.Type): - raise TypeError("dtype is not a MindSpore data type.") - kwargs["dtype"] = dtype - - kwargs["operator"] = operator - kwargs["constant"] = constant - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -260,22 +135,12 @@ def check_pad_end(method): @wraps(method) def new_method(self, *args, **kwargs): - pad_shape, pad_value = (list(args) + 2 * [None])[:2] - if "pad_shape" in kwargs: - pad_shape = kwargs.get("pad_shape") - if "pad_value" in kwargs: - pad_value = kwargs.get("pad_value") - if pad_shape is None: - raise ValueError("pad_shape is not provided.") + [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs) if pad_value is not None: - if not isinstance(pad_value, (str, float, bool, int, bytes)): - raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes") - kwargs["pad_value"] = pad_value - - if not isinstance(pad_shape, list): - raise TypeError("pad_shape must be a list") + type_check(pad_value, (str, float, bool, int, bytes), "pad_value") + 
type_check(pad_shape, (list,), "pad_end") for dim in pad_shape: if dim is not None: @@ -284,9 +149,7 @@ def check_pad_end(method): else: raise TypeError("a value in the list is not an integer.") - kwargs["pad_shape"] = pad_shape - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -296,31 +159,24 @@ def check_concat_type(method): @wraps(method) def new_method(self, *args, **kwargs): - axis, prepend, append = (list(args) + 3 * [None])[:3] - if "prepend" in kwargs: - prepend = kwargs.get("prepend") - if "append" in kwargs: - append = kwargs.get("append") - if "axis" in kwargs: - axis = kwargs.get("axis") + + [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs) if axis is not None: - if not isinstance(axis, int): - raise TypeError("axis type is not valid, must be an integer.") + type_check(axis, (int,), "axis") if axis not in (0, -1): raise ValueError("only 1D concatenation supported.") - kwargs["axis"] = axis if prepend is not None: - if not isinstance(prepend, (type(None), np.ndarray)): - raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") - kwargs["prepend"] = prepend + type_check(prepend, (np.ndarray,), "prepend") + if len(prepend.shape) != 1: + raise ValueError("can only prepend 1D arrays.") if append is not None: - if not isinstance(append, (type(None), np.ndarray)): - raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") - kwargs["append"] = append + type_check(append, (np.ndarray,), "append") + if len(append.shape) != 1: + raise ValueError("can only append 1D arrays.") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 43ac037541e..8e3b7c72141 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ 
b/mindspore/dataset/transforms/vision/c_transforms.py @@ -40,12 +40,14 @@ Examples: >>> dataset = dataset.map(input_columns="image", operations=transforms_list) >>> dataset = dataset.map(input_columns="label", operations=onehot_op) """ +import numbers import mindspore._c_dataengine as cde from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ - check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ - check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp + check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ + check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ + FLOAT_MAX_INTEGER DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT, Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} +def parse_padding(padding): + if isinstance(padding, numbers.Number): + padding = [padding] * 4 + if len(padding) == 2: + left = right = padding[0] + top = bottom = padding[1] + padding = (left, top, right, bottom,) + if isinstance(padding, list): + padding = tuple(padding) + return padding + + class Decode(cde.DecodeOp): """ Decode the input image in RGB mode. 
@@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp): @check_random_crop def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): + if isinstance(size, int): + size = (size, size) + if padding is None: + padding = (0, 0, 0, 0) + else: + padding = parse_padding(padding) + if isinstance(fill_value, int): # temporary fix + fill_value = tuple([fill_value] * 3) + border_type = DE_C_BORDER_TYPE[padding_mode] + self.size = size self.padding = padding self.pad_if_needed = pad_if_needed self.fill_value = fill_value self.padding_mode = padding_mode.value - if padding is None: - padding = (0, 0, 0, 0) - if isinstance(fill_value, int): # temporary fix - fill_value = tuple([fill_value] * 3) - border_type = DE_C_BORDER_TYPE[padding_mode] + super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) @@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp): @check_random_crop def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): + if isinstance(size, int): + size = (size, size) + if padding is None: + padding = (0, 0, 0, 0) + else: + padding = parse_padding(padding) + + if isinstance(fill_value, int): # temporary fix + fill_value = tuple([fill_value] * 3) + border_type = DE_C_BORDER_TYPE[padding_mode] + self.size = size self.padding = padding self.pad_if_needed = pad_if_needed self.fill_value = fill_value self.padding_mode = padding_mode.value - if padding is None: - padding = (0, 0, 0, 0) - if isinstance(fill_value, int): # temporary fix - fill_value = tuple([fill_value] * 3) - border_type = DE_C_BORDER_TYPE[padding_mode] + super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) @@ -292,6 +319,8 @@ class Resize(cde.ResizeOp): @check_resize_interpolation def __init__(self, size, interpolation=Inter.LINEAR): + if isinstance(size, int): + size = (size, size) self.size = size self.interpolation = interpolation interpoltn = 
DE_C_INTER_MODE[interpolation] @@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): @check_random_resize_crop def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10): + if isinstance(size, int): + size = (size, size) self.size = size self.scale = scale self.ratio = ratio @@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp): @check_random_resize_crop def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10): + if isinstance(size, int): + size = (size, size) self.size = size self.scale = scale self.ratio = ratio @@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp): @check_crop def __init__(self, size): + if isinstance(size, int): + size = (size, size) self.size = size super().__init__(*size) @@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp): @check_random_color_adjust def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)): + brightness = self.expand_values(brightness) + contrast = self.expand_values(contrast) + saturation = self.expand_values(saturation) + hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False) + self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue + super().__init__(*brightness, *contrast, *saturation, *hue) + def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True): + if isinstance(value, numbers.Number): + value = [center - value, center + value] + if non_negative: + value[0] = max(0, value[0]) + check_range(value, bound) + return (value[0], value[1]) + class RandomRotation(cde.RandomRotationOp): """ @@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp): self.expand = expand self.center = center self.fill_value = fill_value + if isinstance(degrees, numbers.Number): + degrees = (-degrees, degrees) if center is None: 
center = (-1, -1) if isinstance(fill_value, int): # temporary fix @@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): @check_random_resize_crop def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10): + if isinstance(size, int): + size = (size, size) self.size = size self.scale = scale self.ratio = ratio @@ -623,12 +676,14 @@ class Pad(cde.PadOp): @check_pad def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): - self.padding = padding - self.fill_value = fill_value - self.padding_mode = padding_mode + padding = parse_padding(padding) if isinstance(fill_value, int): # temporary fix fill_value = tuple([fill_value] * 3) padding_mode = DE_C_BORDER_TYPE[padding_mode] + + self.padding = padding + self.fill_value = fill_value + self.padding_mode = padding_mode super().__init__(*padding, padding_mode, *fill_value) diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index b252c3434b9..3bfd6b0644f 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -28,6 +28,7 @@ import numpy as np from PIL import Image from . 
import py_transforms_util as util +from .c_transforms import parse_padding from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ @@ -295,6 +296,10 @@ class RandomCrop: @check_random_crop def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): + if padding is None: + padding = (0, 0, 0, 0) + else: + padding = parse_padding(padding) self.size = size self.padding = padding self.pad_if_needed = pad_if_needed @@ -753,6 +758,8 @@ class TenCrop: @check_ten_crop def __init__(self, size, use_vertical_flip=False): + if isinstance(size, int): + size = (size, size) self.size = size self.use_vertical_flip = use_vertical_flip @@ -877,6 +884,8 @@ class Pad: @check_pad def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): + parse_padding(padding) + self.padding = padding self.fill_value = fill_value self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] @@ -1129,56 +1138,23 @@ class RandomAffine: def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0): # Parameter checking # rotation - if isinstance(degrees, numbers.Number): - if degrees < 0: - raise ValueError("If degrees is a single number, it must be positive.") - self.degrees = (-degrees, degrees) - elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: - self.degrees = degrees - else: - raise TypeError("If degrees is a list or tuple, it must be of length 2.") - - # translation - if translate is not None: - if isinstance(translate, (tuple, list)) and len(translate) == 2: - for t in translate: - if t < 0.0 or t > 1.0: - raise ValueError("translation values should be between 0 and 1") - else: - raise TypeError("translate should be a list or tuple of length 2.") - self.translate = 
translate - - # scale - if scale is not None: - if isinstance(scale, (tuple, list)) and len(scale) == 2: - for s in scale: - if s <= 0: - raise ValueError("scale values should be positive") - else: - raise TypeError("scale should be a list or tuple of length 2.") - self.scale_ranges = scale - - # shear if shear is not None: if isinstance(shear, numbers.Number): - if shear < 0: - raise ValueError("If shear is a single number, it must be positive.") - self.shear = (-1 * shear, shear) - elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4): - # X-Axis shear with [min, max] - if len(shear) == 2: - self.shear = [shear[0], shear[1], 0., 0.] - elif len(shear) == 4: - self.shear = [s for s in shear] + shear = (-1 * shear, shear) else: - raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.") - else: - self.shear = shear + if len(shear) == 2: + shear = [shear[0], shear[1], 0., 0.] + elif len(shear) == 4: + shear = [s for s in shear] - # resample + if isinstance(degrees, numbers.Number): + degrees = (-degrees, degrees) + + self.degrees = degrees + self.translate = translate + self.scale_ranges = scale + self.shear = shear self.resample = DE_PY_INTER_MODE[resample] - - # fill_value self.fill_value = fill_value def __call__(self, img): diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index b49116349bb..078845227df 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -16,47 +16,35 @@ """ import numbers from functools import wraps - +import numpy as np from mindspore._c_dataengine import TensorOp from .utils import Inter, Border -from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \ - check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX - - -def check_inter_mode(mode): - if not isinstance(mode, 
Inter): - raise ValueError("Invalid interpolation mode.") - - -def check_border_type(mode): - if not isinstance(mode, Border): - raise ValueError("Invalid padding mode.") +from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ + check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list def check_crop_size(size): """Wrapper method to check the parameters of crop size.""" + type_check(size, (int, list, tuple), "size") if isinstance(size, int): - size = (size, size) + check_value(size, (1, FLOAT_MAX_INTEGER)) elif isinstance(size, (tuple, list)) and len(size) == 2: - size = size + for value in size: + check_value(value, (1, FLOAT_MAX_INTEGER)) else: raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") - for value in size: - check_pos_int32(value) - return size def check_resize_size(size): """Wrapper method to check the parameters of resize.""" if isinstance(size, int): - check_pos_int32(size) + check_value(size, (1, FLOAT_MAX_INTEGER)) elif isinstance(size, (tuple, list)) and len(size) == 2: - for value in size: - check_value(value, (1, INT32_MAX)) + for i, value in enumerate(size): + check_value(value, (1, INT32_MAX), "size at dim {0}".format(i)) else: raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") - return size def check_normalize_c_param(mean, std): @@ -72,9 +60,9 @@ def check_normalize_py_param(mean, std): if len(mean) != len(std): raise ValueError("Length of mean and std must be equal") for mean_value in mean: - check_value(mean_value, [0., 1.]) + check_value(mean_value, [0., 1.], "mean_value") for std_value in std: - check_value(std_value, [0., 1.]) + check_value(std_value, [0., 1.], "std_value") def check_fill_value(fill_value): @@ -85,66 +73,37 @@ def check_fill_value(fill_value): check_uint8(value) else: raise TypeError("fill_value should be a single integer or a 3-tuple.") - return fill_value def 
check_padding(padding): """Parsing the padding arguments and check if it is legal.""" - if isinstance(padding, numbers.Number): - top = bottom = left = right = padding - - elif isinstance(padding, (tuple, list)): - if len(padding) == 2: - left = right = padding[0] - top = bottom = padding[1] - elif len(padding) == 4: - left = padding[0] - top = padding[1] - right = padding[2] - bottom = padding[3] - else: + type_check(padding, (tuple, list, numbers.Number), "padding") + if isinstance(padding, (tuple, list)): + if len(padding) not in (2, 4): raise ValueError("The size of the padding list or tuple should be 2 or 4.") - else: - raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.") - if not (isinstance(left, int) and isinstance(top, int) and isinstance(right, int) and isinstance(bottom, int)): - raise TypeError("Padding value should be integer.") - if left < 0 or top < 0 or right < 0 or bottom < 0: - raise ValueError("Padding value could not be negative.") - return left, top, right, bottom + for i, pad_value in enumerate(padding): + type_check(pad_value, (int,), "padding[{}]".format(i)) + check_value(pad_value, (0, INT32_MAX), "pad_value") def check_degrees(degrees): """Check if the degrees is legal.""" + type_check(degrees, (numbers.Number, list, tuple), "degrees") if isinstance(degrees, numbers.Number): - if degrees < 0: - raise ValueError("If degrees is a single number, it cannot be negative.") - degrees = (-degrees, degrees) + check_value(degrees, (0, float("inf")), "degrees") elif isinstance(degrees, (list, tuple)): if len(degrees) != 2: raise TypeError("If degrees is a sequence, the length must be 2.") - else: - raise TypeError("Degrees must be a single non-negative number or a sequence") - return degrees def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True): """Check the parameters in random color adjust operation.""" + type_check(value, (numbers.Number, list, tuple), 
input_name) if isinstance(value, numbers.Number): if value < 0: raise ValueError("The input value of {} cannot be negative.".format(input_name)) - # convert value into a range - value = [center - value, center + value] - if non_negative: - value[0] = max(0, value[0]) elif isinstance(value, (list, tuple)) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError("Please check your value range of {} is valid and " - "within the bound {}".format(input_name, bound)) - else: - raise TypeError("Input of {} should be either a single value, or a list/tuple of " - "length 2.".format(input_name)) - factor = (value[0], value[1]) - return factor + check_range(value, bound) def check_erasing_value(value): @@ -159,15 +118,10 @@ def check_crop(method): @wraps(method) def new_method(self, *args, **kwargs): - size = (list(args) + [None])[0] - if "size" in kwargs: - size = kwargs.get("size") - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size + [size], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -177,23 +131,12 @@ def check_resize_interpolation(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - size, interpolation = args - if "size" in kwargs: - size = kwargs.get("size") - if "interpolation" in kwargs: - interpolation = kwargs.get("interpolation") - - if size is None: - raise ValueError("size is not provided.") - size = check_resize_size(size) - kwargs["size"] = size - + [size, interpolation], _ = parse_user_args(method, *args, **kwargs) + check_resize_size(size) if interpolation is not None: - check_inter_mode(interpolation) - kwargs["interpolation"] = interpolation + type_check(interpolation, (Inter,), "interpolation") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ 
-203,16 +146,10 @@ def check_resize(method): @wraps(method) def new_method(self, *args, **kwargs): - size = (list(args) + [None])[0] - if "size" in kwargs: - size = kwargs.get("size") + [size], _ = parse_user_args(method, *args, **kwargs) + check_resize_size(size) - if size is None: - raise ValueError("size is not provided.") - size = check_resize_size(size) - kwargs["size"] = size - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -222,39 +159,20 @@ def check_random_resize_crop(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 5 * [None])[:5] - size, scale, ratio, interpolation, max_attempts = args - if "size" in kwargs: - size = kwargs.get("size") - if "scale" in kwargs: - scale = kwargs.get("scale") - if "ratio" in kwargs: - ratio = kwargs.get("ratio") - if "interpolation" in kwargs: - interpolation = kwargs.get("interpolation") - if "max_attempts" in kwargs: - max_attempts = kwargs.get("max_attempts") - - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size + [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) if scale is not None: check_range(scale, [0, FLOAT_MAX_INTEGER]) - kwargs["scale"] = scale if ratio is not None: check_range(ratio, [0, FLOAT_MAX_INTEGER]) - check_positive(ratio[0]) - kwargs["ratio"] = ratio + check_positive(ratio[0], "ratio[0]") if interpolation is not None: - check_inter_mode(interpolation) - kwargs["interpolation"] = interpolation + type_check(interpolation, (Inter,), "interpolation") if max_attempts is not None: - check_pos_int32(max_attempts) - kwargs["max_attempts"] = max_attempts + check_value(max_attempts, (1, FLOAT_MAX_INTEGER)) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -264,14 +182,11 @@ def check_prob(method): @wraps(method) def new_method(self, *args, **kwargs): - prob = 
(list(args) + [None])[0] - if "prob" in kwargs: - prob = kwargs.get("prob") - if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob + [prob], _ = parse_user_args(method, *args, **kwargs) + type_check(prob, (float, int,), "prob") + check_value(prob, [0., 1.], "prob") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -281,22 +196,10 @@ def check_normalize_c(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - mean, std = args - if "mean" in kwargs: - mean = kwargs.get("mean") - if "std" in kwargs: - std = kwargs.get("std") - - if mean is None: - raise ValueError("mean is not provided.") - if std is None: - raise ValueError("std is not provided.") + [mean, std], _ = parse_user_args(method, *args, **kwargs) check_normalize_c_param(mean, std) - kwargs["mean"] = mean - kwargs["std"] = std - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -306,22 +209,10 @@ def check_normalize_py(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - mean, std = args - if "mean" in kwargs: - mean = kwargs.get("mean") - if "std" in kwargs: - std = kwargs.get("std") - - if mean is None: - raise ValueError("mean is not provided.") - if std is None: - raise ValueError("std is not provided.") + [mean, std], _ = parse_user_args(method, *args, **kwargs) check_normalize_py_param(mean, std) - kwargs["mean"] = mean - kwargs["std"] = std - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -331,38 +222,17 @@ def check_random_crop(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 5 * [None])[:5] - size, padding, pad_if_needed, fill_value, padding_mode = args - - if "size" in kwargs: - size = kwargs.get("size") - if "padding" in kwargs: - padding = kwargs.get("padding") - if "fill_value" in kwargs: - fill_value = 
kwargs.get("fill_value") - if "padding_mode" in kwargs: - padding_mode = kwargs.get("padding_mode") - if "pad_if_needed" in kwargs: - pad_if_needed = kwargs.get("pad_if_needed") - - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size - + [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) + type_check(pad_if_needed, (bool,), "pad_if_needed") if padding is not None: - padding = check_padding(padding) - kwargs["padding"] = padding + check_padding(padding) if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value + check_fill_value(fill_value) if padding_mode is not None: - check_border_type(padding_mode) - kwargs["padding_mode"] = padding_mode - if pad_if_needed is not None: - kwargs["pad_if_needed"] = pad_if_needed + type_check(padding_mode, (Border,), "padding_mode") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -372,27 +242,13 @@ def check_random_color_adjust(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 4 * [None])[:4] - brightness, contrast, saturation, hue = args - if "brightness" in kwargs: - brightness = kwargs.get("brightness") - if "contrast" in kwargs: - contrast = kwargs.get("contrast") - if "saturation" in kwargs: - saturation = kwargs.get("saturation") - if "hue" in kwargs: - hue = kwargs.get("hue") + [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs) + check_random_color_adjust_param(brightness, "brightness") + check_random_color_adjust_param(contrast, "contrast") + check_random_color_adjust_param(saturation, "saturation") + check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False) - if brightness is not None: - kwargs["brightness"] = check_random_color_adjust_param(brightness, "brightness") - if contrast is not 
None: - kwargs["contrast"] = check_random_color_adjust_param(contrast, "contrast") - if saturation is not None: - kwargs["saturation"] = check_random_color_adjust_param(saturation, "saturation") - if hue is not None: - kwargs["hue"] = check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False) - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -402,38 +258,19 @@ def check_random_rotation(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 5 * [None])[:5] - degrees, resample, expand, center, fill_value = args - if "degrees" in kwargs: - degrees = kwargs.get("degrees") - if "resample" in kwargs: - resample = kwargs.get("resample") - if "expand" in kwargs: - expand = kwargs.get("expand") - if "center" in kwargs: - center = kwargs.get("center") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - - if degrees is None: - raise ValueError("degrees is not provided.") - degrees = check_degrees(degrees) - kwargs["degrees"] = degrees + [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs) + check_degrees(degrees) if resample is not None: - check_inter_mode(resample) - kwargs["resample"] = resample + type_check(resample, (Inter,), "resample") if expand is not None: - check_bool(expand) - kwargs["expand"] = expand + type_check(expand, (bool,), "expand") if center is not None: - check_2tuple(center) - kwargs["center"] = center + check_2tuple(center, "center") if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value + check_fill_value(fill_value) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -443,16 +280,11 @@ def check_transforms_list(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms = (list(args) + [None])[0] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms 
is None: - raise ValueError("transforms is not provided.") + [transforms], _ = parse_user_args(method, *args, **kwargs) - check_list(transforms) - kwargs["transforms"] = transforms + type_check(transforms, (list,), "transforms") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -462,21 +294,14 @@ def check_random_apply(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms, prob = (list(args) + 2 * [None])[:2] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is not provided.") - check_list(transforms) - kwargs["transforms"] = transforms + [transforms, prob], _ = parse_user_args(method, *args, **kwargs) + type_check(transforms, (list,), "transforms") - if "prob" in kwargs: - prob = kwargs.get("prob") if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob + type_check(prob, (float, int,), "prob") + check_value(prob, [0., 1.], "prob") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -486,23 +311,13 @@ def check_ten_crop(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - size, use_vertical_flip = args - if "size" in kwargs: - size = kwargs.get("size") - if "use_vertical_flip" in kwargs: - use_vertical_flip = kwargs.get("use_vertical_flip") - - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size + [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) if use_vertical_flip is not None: - check_bool(use_vertical_flip) - kwargs["use_vertical_flip"] = use_vertical_flip + type_check(use_vertical_flip, (bool,), "use_vertical_flip") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -512,16 +327,13 @@ def check_num_channels(method): @wraps(method) def new_method(self, *args, 
**kwargs): - num_output_channels = (list(args) + [None])[0] - if "num_output_channels" in kwargs: - num_output_channels = kwargs.get("num_output_channels") + [num_output_channels], _ = parse_user_args(method, *args, **kwargs) if num_output_channels is not None: if num_output_channels not in (1, 3): raise ValueError("Number of channels of the output grayscale image" "should be either 1 or 3. Got {0}".format(num_output_channels)) - kwargs["num_output_channels"] = num_output_channels - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -531,28 +343,12 @@ def check_pad(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 3 * [None])[:3] - padding, fill_value, padding_mode = args - if "padding" in kwargs: - padding = kwargs.get("padding") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - if "padding_mode" in kwargs: - padding_mode = kwargs.get("padding_mode") + [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs) + check_padding(padding) + check_fill_value(fill_value) + type_check(padding_mode, (Border,), "padding_mode") - if padding is None: - raise ValueError("padding is not provided.") - padding = check_padding(padding) - kwargs["padding"] = padding - - if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value - if padding_mode is not None: - check_border_type(padding_mode) - kwargs["padding_mode"] = padding_mode - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -562,26 +358,13 @@ def check_random_perspective(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 3 * [None])[:3] - distortion_scale, prob, interpolation = args - if "distortion_scale" in kwargs: - distortion_scale = kwargs.get("distortion_scale") - if "prob" in kwargs: - prob = kwargs.get("prob") - if "interpolation" in kwargs: - interpolation = 
kwargs.get("interpolation") + [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs) - if distortion_scale is not None: - check_value(distortion_scale, [0., 1.]) - kwargs["distortion_scale"] = distortion_scale - if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob - if interpolation is not None: - check_inter_mode(interpolation) - kwargs["interpolation"] = interpolation + check_value(distortion_scale, [0., 1.], "distortion_scale") + check_value(prob, [0., 1.], "prob") + type_check(interpolation, (Inter,), "interpolation") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -591,28 +374,13 @@ def check_mix_up(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 3 * [None])[:3] - batch_size, alpha, is_single = args - if "batch_size" in kwargs: - batch_size = kwargs.get("batch_size") - if "alpha" in kwargs: - alpha = kwargs.get("alpha") - if "is_single" in kwargs: - is_single = kwargs.get("is_single") + [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs) - if batch_size is None: - raise ValueError("batch_size") - check_pos_int32(batch_size) - kwargs["batch_size"] = batch_size - if alpha is None: - raise ValueError("alpha") - check_positive(alpha) - kwargs["alpha"] = alpha - if is_single is not None: - check_type(is_single, bool) - kwargs["is_single"] = is_single + check_value(batch_size, (1, FLOAT_MAX_INTEGER)) + check_positive(alpha, "alpha") + type_check(is_single, (bool,), "is_single") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -622,41 +390,16 @@ def check_random_erasing(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 6 * [None])[:6] - prob, scale, ratio, value, inplace, max_attempts = args - if "prob" in kwargs: - prob = kwargs.get("prob") - if "scale" in kwargs: - scale = kwargs.get("scale") - if "ratio" in kwargs: - ratio = 
kwargs.get("ratio") - if "value" in kwargs: - value = kwargs.get("value") - if "inplace" in kwargs: - inplace = kwargs.get("inplace") - if "max_attempts" in kwargs: - max_attempts = kwargs.get("max_attempts") + [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs) - if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob - if scale is not None: - check_range(scale, [0, FLOAT_MAX_INTEGER]) - kwargs["scale"] = scale - if ratio is not None: - check_range(ratio, [0, FLOAT_MAX_INTEGER]) - kwargs["ratio"] = ratio - if value is not None: - check_erasing_value(value) - kwargs["value"] = value - if inplace is not None: - check_bool(inplace) - kwargs["inplace"] = inplace - if max_attempts is not None: - check_pos_int32(max_attempts) - kwargs["max_attempts"] = max_attempts + check_value(prob, [0., 1.], "prob") + check_range(scale, [0, FLOAT_MAX_INTEGER]) + check_range(ratio, [0, FLOAT_MAX_INTEGER]) + check_erasing_value(value) + type_check(inplace, (bool,), "inplace") + check_value(max_attempts, (1, FLOAT_MAX_INTEGER)) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -666,23 +409,12 @@ def check_cutout(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - length, num_patches = args - if "length" in kwargs: - length = kwargs.get("length") - if "num_patches" in kwargs: - num_patches = kwargs.get("num_patches") + [length, num_patches], _ = parse_user_args(method, *args, **kwargs) - if length is None: - raise ValueError("length") - check_pos_int32(length) - kwargs["length"] = length + check_value(length, (1, FLOAT_MAX_INTEGER)) + check_value(num_patches, (1, FLOAT_MAX_INTEGER)) - if num_patches is not None: - check_pos_int32(num_patches) - kwargs["num_patches"] = num_patches - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -692,17 +424,9 @@ def check_linear_transform(method): 
@wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - transformation_matrix, mean_vector = args - if "transformation_matrix" in kwargs: - transformation_matrix = kwargs.get("transformation_matrix") - if "mean_vector" in kwargs: - mean_vector = kwargs.get("mean_vector") - - if transformation_matrix is None: - raise ValueError("transformation_matrix is not provided.") - if mean_vector is None: - raise ValueError("mean_vector is not provided.") + [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs) + type_check(transformation_matrix, (np.ndarray,), "transformation_matrix") + type_check(mean_vector, (np.ndarray,), "mean_vector") if transformation_matrix.shape[0] != transformation_matrix.shape[1]: raise ValueError("transformation_matrix should be a square matrix. " @@ -711,10 +435,7 @@ def check_linear_transform(method): raise ValueError("mean_vector length {0} should match either one dimension of the square" "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape)) - kwargs["transformation_matrix"] = transformation_matrix - kwargs["mean_vector"] = mean_vector - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -724,67 +445,40 @@ def check_random_affine(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 6 * [None])[:6] - degrees, translate, scale, shear, resample, fill_value = args - if "degrees" in kwargs: - degrees = kwargs.get("degrees") - if "translate" in kwargs: - translate = kwargs.get("translate") - if "scale" in kwargs: - scale = kwargs.get("scale") - if "shear" in kwargs: - shear = kwargs.get("shear") - if "resample" in kwargs: - resample = kwargs.get("resample") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - - if degrees is None: - raise ValueError("degrees is not provided.") - degrees = check_degrees(degrees) - kwargs["degrees"] = degrees + [degrees, 
translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs) + check_degrees(degrees) if translate is not None: - if isinstance(translate, (tuple, list)) and len(translate) == 2: - for t in translate: - if t < 0.0 or t > 1.0: - raise ValueError("translation values should be between 0 and 1") - else: + if type_check(translate, (list, tuple), "translate"): + translate_names = ["translate_{0}".format(i) for i in range(len(translate))] + type_check_list(translate, (int, float), translate_names) + if len(translate) != 2: raise TypeError("translate should be a list or tuple of length 2.") - kwargs["translate"] = translate + for i, t in enumerate(translate): + check_value(t, [0.0, 1.0], "translate at {0}".format(i)) if scale is not None: - if isinstance(scale, (tuple, list)) and len(scale) == 2: - for s in scale: - if s <= 0: - raise ValueError("scale values should be positive") + type_check(scale, (tuple, list), "scale") + if len(scale) == 2: + for i, s in enumerate(scale): + check_positive(s, "scale[{}]".format(i)) else: raise TypeError("scale should be a list or tuple of length 2.") - kwargs["scale"] = scale if shear is not None: + type_check(shear, (numbers.Number, tuple, list), "shear") if isinstance(shear, numbers.Number): - if shear < 0: - raise ValueError("If shear is a single number, it must be positive.") - shear = (-1 * shear, shear) - elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4): - # X-Axis shear with [min, max] - if len(shear) == 2: - shear = [shear[0], shear[1], 0., 0.] 
- elif len(shear) == 4: - shear = [s for s in shear] + check_positive(shear, "shear") else: - raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.") - kwargs["shear"] = shear + if len(shear) not in (2, 4): + raise TypeError("shear must be of length 2 or 4.") + + type_check(resample, (Inter,), "resample") - if resample is not None: - check_inter_mode(resample) - kwargs["resample"] = resample if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value + check_fill_value(fill_value) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -794,24 +488,11 @@ def check_rescale(method): @wraps(method) def new_method(self, *args, **kwargs): - rescale, shift = (list(args) + 2 * [None])[:2] - if "rescale" in kwargs: - rescale = kwargs.get("rescale") - if "shift" in kwargs: - shift = kwargs.get("shift") - - if rescale is None: - raise ValueError("rescale is not provided.") + [rescale, shift], _ = parse_user_args(method, *args, **kwargs) check_pos_float32(rescale) - kwargs["rescale"] = rescale + type_check(shift, (numbers.Number,), "shift") - if shift is None: - raise ValueError("shift is not provided.") - if not isinstance(shift, numbers.Number): - raise TypeError("shift is not a number.") - kwargs["shift"] = shift - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -821,33 +502,16 @@ def check_uniform_augment_cpp(method): @wraps(method) def new_method(self, *args, **kwargs): - operations, num_ops = (list(args) + 2 * [None])[:2] - if "operations" in kwargs: - operations = kwargs.get("operations") - else: - raise ValueError("operations list required") - if "num_ops" in kwargs: - num_ops = kwargs.get("num_ops") - else: - num_ops = 2 + [operations, num_ops], _ = parse_user_args(method, *args, **kwargs) + type_check(num_ops, (int,), "num_ops") + check_positive(num_ops, "num_ops") - if not isinstance(num_ops, int): - raise 
ValueError("Number of operations should be an integer.") - - if num_ops <= 0: - raise ValueError("num_ops should be greater than zero") if num_ops > len(operations): raise ValueError("num_ops is greater than operations list size") - if not isinstance(operations, list): - raise TypeError("operations is not a python list") - for op in operations: - if not isinstance(op, TensorOp): - raise ValueError("operations list only accepts C++ operations.") + tensor_ops = ["tensor_op_{0}".format(i) for i in range(len(operations))] + type_check_list(operations, (TensorOp,), tensor_ops) - kwargs["num_ops"] = num_ops - kwargs["operations"] = operations - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -857,23 +521,11 @@ def check_bounding_box_augment_cpp(method): @wraps(method) def new_method(self, *args, **kwargs): - transform, ratio = (list(args) + 2 * [None])[:2] - if "transform" in kwargs: - transform = kwargs.get("transform") - if "ratio" in kwargs: - ratio = kwargs.get("ratio") - if not isinstance(ratio, float) and not isinstance(ratio, int): - raise ValueError("Ratio should be an int or float.") - if ratio is not None: - check_value(ratio, [0., 1.]) - kwargs["ratio"] = ratio - else: - ratio = 0.3 - if not isinstance(transform, TensorOp): - raise ValueError("Transform can only be a C++ operation.") - kwargs["transform"] = transform - kwargs["ratio"] = ratio - return method(self, **kwargs) + [transform, ratio], _ = parse_user_args(method, *args, **kwargs) + type_check(ratio, (float, int), "ratio") + check_value(ratio, [0., 1.], "ratio") + type_check(transform, (TensorOp,), "transform") + return method(self, *args, **kwargs) return new_method @@ -883,29 +535,22 @@ def check_uniform_augment_py(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms, num_ops = (list(args) + 2 * [None])[:2] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is 
not provided.") + [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs) + type_check(transforms, (list,), "transforms") + if not transforms: raise ValueError("transforms list is empty.") - check_list(transforms) + for transform in transforms: if isinstance(transform, TensorOp): raise ValueError("transform list only accepts Python operations.") - kwargs["transforms"] = transforms - if "num_ops" in kwargs: - num_ops = kwargs.get("num_ops") - if num_ops is not None: - check_type(num_ops, int) - check_positive(num_ops) - if num_ops > len(transforms): - raise ValueError("num_ops cannot be greater than the length of transforms list.") - kwargs["num_ops"] = num_ops + type_check(num_ops, (int,), "num_ops") + check_positive(num_ops, "num_ops") + if num_ops > len(transforms): + raise ValueError("num_ops cannot be greater than the length of transforms list.") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -915,22 +560,16 @@ def check_positive_degrees(method): @wraps(method) def new_method(self, *args, **kwargs): - degrees = (list(args) + [None])[0] - if "degrees" in kwargs: - degrees = kwargs.get("degrees") + [degrees], _ = parse_user_args(method, *args, **kwargs) - if degrees is not None: - if isinstance(degrees, (list, tuple)): - if len(degrees) != 2: - raise ValueError("Degrees must be a sequence with length 2.") - if degrees[0] < 0: - raise ValueError("Degrees range must be non-negative.") - if degrees[0] > degrees[1]: - raise ValueError("Degrees should be in (min,max) format. Got (max,min).") - else: - raise TypeError("Degrees must be a sequence in (min,max) format.") + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence with length 2.") + check_positive(degrees[0], "degrees[0]") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. 
Got (max,min).") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -940,18 +579,12 @@ def check_compose_list(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms = (list(args) + [None])[0] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is not provided.") + [transforms], _ = parse_user_args(method, *args, **kwargs) + + type_check(transforms, (list,), transforms) if not transforms: raise ValueError("transforms list is empty.") - if not isinstance(transforms, list): - raise TypeError("transforms is not a python list") - kwargs["transforms"] = transforms - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/tests/ut/python/dataset/test_bounding_box_augment.py b/tests/ut/python/dataset/test_bounding_box_augment.py index fe02dcebc7f..4cde4da0042 100644 --- a/tests/ut/python/dataset/test_bounding_box_augment.py +++ b/tests/ut/python/dataset/test_bounding_box_augment.py @@ -15,13 +15,15 @@ """ Testing the bounding box augment op in DE """ -from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ - config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + import numpy as np import mindspore.log as logger import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision +from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ + config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + GENERATE_GOLDEN = False # updated VOC dataset with correct annotations @@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c(): operations=[test_op]) # Add column for "annotation" except ValueError as error: logger.info("Got an exception in DE: {}".format(str(error))) - assert "Input is not" in str(error) + assert "Input ratio is not within the required interval of 
(0.0 to 1.0)." in str(error) def test_bounding_box_augment_invalid_bounds_c(): diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index febcc6483f7..a30b5827cb5 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -17,6 +17,7 @@ import pytest import numpy as np import mindspore.dataset as ds + # generates 1 column [0], [0, 1], ..., [0, ..., n-1] def generate_sequential(n): for i in range(n): @@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input(): with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, None, None, invalid_type_pad_to_bucket_boundary) - assert "Wrong input type for pad_to_bucket_boundary, should be " in str(info.value) + assert "Argument pad_to_bucket_boundary with value \"\" is not of type (,)." in str(info.value) with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, None, None, False, invalid_type_drop_remainder) - assert "Wrong input type for drop_remainder, should be " in str(info.value) + assert "Argument drop_remainder with value \"\" is not of type (,)." 
in str(info.value) def test_bucket_batch_multi_bucket_no_padding(): @@ -272,7 +273,6 @@ def test_bucket_batch_default_pad(): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] - output = [] for data in dataset.create_dict_iterator(): output.append(data["col1"].tolist()) diff --git a/tests/ut/python/dataset/test_concatenate_op.py b/tests/ut/python/dataset/test_concatenate_op.py index d04ff497242..fa293c3b34f 100644 --- a/tests/ut/python/dataset/test_concatenate_op.py +++ b/tests/ut/python/dataset/test_concatenate_op.py @@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis(): def test_concatenate_op_incorrect_input_dim(): - def gen(): - yield (np.array(["ss", "ad"], dtype='S'),) - prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') - data = ds.GeneratorDataset(gen, column_names=["col"]) - concatenate_op = data_trans.Concatenate(0, prepend_tensor) - data = data.map(input_columns=["col"], operations=concatenate_op) - with pytest.raises(RuntimeError) as error_info: - for _ in data: - pass - assert "Only 1D tensors supported" in repr(error_info.value) + with pytest.raises(ValueError) as error_info: + data_trans.Concatenate(0, prepend_tensor) + assert "can only prepend 1D arrays." in repr(error_info.value) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_exceptions.py b/tests/ut/python/dataset/test_exceptions.py index cbfa402bb06..253eb564aeb 100644 --- a/tests/ut/python/dataset/test_exceptions.py +++ b/tests/ut/python/dataset/test_exceptions.py @@ -28,9 +28,9 @@ def test_exception_01(): """ logger.info("test_exception_01") data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"]) - with pytest.raises(ValueError) as info: - data = data.map(input_columns=["image"], operations=vision.Resize(100, 100)) - assert "Invalid interpolation mode." 
in str(info.value) + with pytest.raises(TypeError) as info: + data.map(input_columns=["image"], operations=vision.Resize(100, 100)) + assert "Argument interpolation with value 100 is not of type (,)" in str(info.value) def test_exception_02(): @@ -40,8 +40,8 @@ def test_exception_02(): logger.info("test_exception_02") num_samples = -1 with pytest.raises(ValueError) as info: - data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) - assert "num_samples cannot be less than 0" in str(info.value) + ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) + assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value) num_samples = 1 data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) diff --git a/tests/ut/python/dataset/test_from_dataset.py b/tests/ut/python/dataset/test_from_dataset.py index 207a6be6a1a..514276fe704 100644 --- a/tests/ut/python/dataset/test_from_dataset.py +++ b/tests/ut/python/dataset/test_from_dataset.py @@ -23,7 +23,8 @@ import mindspore.dataset.text as text def test_demo_basic_from_dataset(): """ this is a tutorial on how from_dataset should be used in a normal use case""" data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) - vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["", ""], + vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, + special_tokens=["", ""], special_first=True) data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) res = [] @@ -127,15 +128,16 @@ def test_from_dataset_exceptions(): data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k) assert isinstance(vocab.text.Vocab) - except ValueError as e: + except (TypeError, ValueError, RuntimeError) as e: assert s in str(e), 
str(e) - test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers") - test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer") - test_config(23, (2, 3), 1.2345, "columns need to be a list of strings") - test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b") - test_config("text", (2, 3), 0, "top_k needs to be a positive integer") - test_config([123], (2, 3), 0, "columns need to be a list of strings") + test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") + test_config("text", (2, 3), 1.2345, + "Argument top_k with value 1.2345 is not of type (, )") + test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (,)") + test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") + test_config("text", (2, 3), 0, "top_k needs to be positive number") + test_config([123], (2, 3), 0, "top_k needs to be positive number") if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_linear_transformation.py b/tests/ut/python/dataset/test_linear_transformation.py index 0dd25a4da1e..f932916ed83 100644 --- a/tests/ut/python/dataset/test_linear_transformation.py +++ b/tests/ut/python/dataset/test_linear_transformation.py @@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False): if plot: visualize_list(image, image_transformed) + def test_linear_transformation_md5(): """ Test LinearTransformation op: valid params (transformation_matrix, mean_vector) @@ -102,6 +103,7 @@ def test_linear_transformation_md5(): filename = "linear_transformation_01_result.npz" save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) + def test_linear_transformation_exception_01(): """ Test LinearTransformation op: transformation_matrix is not provided @@ -126,9 +128,10 @@ def test_linear_transformation_exception_01(): ] transform = py_vision.ComposeOp(transforms) data1 = 
data1.map(input_columns=["image"], operations=transform()) - except ValueError as e: + except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "not provided" in str(e) + assert "Argument transformation_matrix with value None is not of type (,)" in str(e) + def test_linear_transformation_exception_02(): """ @@ -154,9 +157,10 @@ def test_linear_transformation_exception_02(): ] transform = py_vision.ComposeOp(transforms) data1 = data1.map(input_columns=["image"], operations=transform()) - except ValueError as e: + except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "not provided" in str(e) + assert "Argument mean_vector with value None is not of type (,)" in str(e) + def test_linear_transformation_exception_03(): """ @@ -187,6 +191,7 @@ def test_linear_transformation_exception_03(): logger.info("Got an exception in DE: {}".format(str(e))) assert "square matrix" in str(e) + def test_linear_transformation_exception_04(): """ Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix @@ -199,7 +204,7 @@ def test_linear_transformation_exception_04(): weight = 50 dim = 3 * height * weight transformation_matrix = np.ones([dim, dim]) - mean_vector = np.zeros(dim-1) + mean_vector = np.zeros(dim - 1) # Generate dataset data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -216,6 +221,7 @@ def test_linear_transformation_exception_04(): logger.info("Got an exception in DE: {}".format(str(e))) assert "should match" in str(e) + if __name__ == '__main__': test_linear_transformation_op(plot=True) test_linear_transformation_md5() diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index b15944d76b8..5ecaeff13ac 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -184,24 +184,26 @@ def 
test_minddataset_invalidate_num_shards(): create_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) + os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) - def test_minddataset_invalidate_shard_id(): create_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard(): create_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 1).' 
in repr(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) diff --git a/tests/ut/python/dataset/test_ngram_op.py b/tests/ut/python/dataset/test_ngram_op.py index 73b2702378e..8887b675000 100644 --- a/tests/ut/python/dataset/test_ngram_op.py +++ b/tests/ut/python/dataset/test_ngram_op.py @@ -15,9 +15,9 @@ """ Testing Ngram in mindspore.dataset """ +import numpy as np import mindspore.dataset as ds import mindspore.dataset.text as text -import numpy as np def test_multiple_ngrams(): @@ -61,7 +61,7 @@ def test_simple_ngram(): yield (np.array(line.split(" "), dtype='S'),) dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None)) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" ")) i = 0 for data in dataset.create_dict_iterator(): @@ -72,7 +72,7 @@ def test_simple_ngram(): def test_corner_cases(): """ testing various corner cases and exceptions""" - def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None): + def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "): def gen(texts): yield (np.array(texts.split(" "), dtype='S'),) @@ -93,7 +93,7 @@ def test_corner_cases(): try: test_config("Yours to Discover", "", [0, [1]]) except Exception as e: - assert "ngram needs to be a positive number" in str(e) + assert "Argument gram[1] with value [1] is not of type (,)" in str(e) # test empty n try: test_config("Yours to Discover", "", []) diff --git a/tests/ut/python/dataset/test_normalizeOp.py b/tests/ut/python/dataset/test_normalizeOp.py index af97ee0c088..d5ebc799f91 100644 --- a/tests/ut/python/dataset/test_normalizeOp.py +++ b/tests/ut/python/dataset/test_normalizeOp.py @@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py(): _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) except ValueError as e: logger.info("Got an exception in DE: 
{}".format(str(e))) - assert "Input is not within the required range" in str(e) + assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e) def test_normalize_grayscale_md5_01(): diff --git a/tests/ut/python/dataset/test_pad_end_op.py b/tests/ut/python/dataset/test_pad_end_op.py index 5742d736659..c25d6b9a95b 100644 --- a/tests/ut/python/dataset/test_pad_end_op.py +++ b/tests/ut/python/dataset/test_pad_end_op.py @@ -61,6 +61,10 @@ def test_pad_end_exceptions(): pad_compare([3, 4, 5], ["2"], 1, []) assert "a value in the list is not an integer." in str(info.value) + with pytest.raises(TypeError) as info: + pad_compare([1, 2], 3, -1, [1, 2, -1]) + assert "Argument pad_end with value 3 is not of type (,)" in str(info.value) + if __name__ == "__main__": test_pad_end_basics() diff --git a/tests/ut/python/dataset/test_random_affine.py b/tests/ut/python/dataset/test_random_affine.py index b856684ed13..ec829eb53a7 100644 --- a/tests/ut/python/dataset/test_random_affine.py +++ b/tests/ut/python/dataset/test_random_affine.py @@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees(): _ = py_vision.RandomAffine(degrees=-15) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "If degrees is a single number, it cannot be negative." + assert str(e) == "Input degrees is not within the required interval of (0 to inf)." def test_random_affine_exception_translation_range(): @@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range(): _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5)) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "translation values should be between 0 and 1" + assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)." 
def test_random_affine_exception_scale_value(): @@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value(): _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1)) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "scale values should be positive" + assert str(e) == "Input scale[0] must be greater than 0." def test_random_affine_exception_shear_value(): @@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value(): _ = py_vision.RandomAffine(degrees=15, shear=-5) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "If shear is a single number, it must be positive." + assert str(e) == "Input shear must be greater than 0." def test_random_affine_exception_degrees_size(): @@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size(): _ = py_vision.RandomAffine(degrees=15, translate=(0.1)) except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "translate should be a list or tuple of length 2." + assert str( + e) == "Argument translate with value 0.1 is not of type (," \ + " )." def test_random_affine_exception_scale_size(): @@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size(): _ = py_vision.RandomAffine(degrees=15, scale=(0.5)) except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "scale should be a list or tuple of length 2." + assert str(e) == "Argument scale with value 0.5 is not of type (," \ + " )." def test_random_affine_exception_shear_size(): @@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size(): _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10)) except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4." + assert str(e) == "shear must be of length 2 or 4." 
if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_random_color.py b/tests/ut/python/dataset/test_random_color.py index 45847ba6534..0015e8498f6 100644 --- a/tests/ut/python/dataset/test_random_color.py +++ b/tests/ut/python/dataset/test_random_color.py @@ -97,7 +97,7 @@ def test_random_color_md5(): data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) transforms = F.ComposeOp([F.Decode(), - F.RandomColor((0.5, 1.5)), + F.RandomColor((0.1, 1.9)), F.ToTensor()]) data = data.map(input_columns="image", operations=transforms()) diff --git a/tests/ut/python/dataset/test_random_crop_and_resize.py b/tests/ut/python/dataset/test_random_crop_and_resize.py index de039e6d82e..486d2cd5ed1 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize.py @@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c(): data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_04_py(): @@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_05_c(): @@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c(): data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." 
in str(e) def test_random_crop_and_resize_05_py(): @@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_comp(plot=False): diff --git a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py index 46c45ecc36c..599acc95609 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py @@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input range is not valid" in str(err) + assert "Input is not within the required interval of (0 to 16777216)." in str(err) def test_random_resized_crop_with_bbox_op_invalid2_c(): @@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input range is not valid" in str(err) + assert "Input is not within the required interval of (0 to 16777216)." 
in str(err) def test_random_resized_crop_with_bbox_op_bad_c(): diff --git a/tests/ut/python/dataset/test_random_grayscale.py b/tests/ut/python/dataset/test_random_grayscale.py index 83514a55f6f..4cb25c3a3a9 100644 --- a/tests/ut/python/dataset/test_random_grayscale.py +++ b/tests/ut/python/dataset/test_random_grayscale.py @@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not within the required range" in str(e) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) if __name__ == "__main__": test_random_grayscale_valid_prob(True) diff --git a/tests/ut/python/dataset/test_random_horizontal_flip.py b/tests/ut/python/dataset/test_random_horizontal_flip.py index 1272148e4fc..ef4f5b8eb6f 100644 --- a/tests/ut/python/dataset/test_random_horizontal_flip.py +++ b/tests/ut/python/dataset/test_random_horizontal_flip.py @@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c(): data = data.map(input_columns=["image"], operations=random_horizontal_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) def test_random_horizontal_invalid_prob_py(): @@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert "Input prob is not within the required interval of (0.0 to 1.0)." 
in str(e) def test_random_horizontal_comp(plot=False): diff --git a/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py b/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py index 02126f25ac6..4fd51a7a035 100644 --- a/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py @@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): operations=[test_op]) # Add column for "annotation" except ValueError as error: logger.info("Got an exception in DE: {}".format(str(error))) - assert "Input is not" in str(error) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error) def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): diff --git a/tests/ut/python/dataset/test_random_perspective.py b/tests/ut/python/dataset/test_random_perspective.py index 66329ddb909..992bf2b2227 100644 --- a/tests/ut/python/dataset/test_random_perspective.py +++ b/tests/ut/python/dataset/test_random_perspective.py @@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range(): _ = py_vision.RandomPerspective(distortion_scale=1.5) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Input is not within the required range" + assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)." def test_random_perspective_exception_prob_range(): @@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range(): _ = py_vision.RandomPerspective(prob=1.2) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Input is not within the required range" + assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)." 
if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_random_resize_with_bbox.py b/tests/ut/python/dataset/test_random_resize_with_bbox.py index 8e2dab33e13..94f9d12427b 100644 --- a/tests/ut/python/dataset/test_random_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_resize_with_bbox.py @@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input is not within the required interval of (1 to 16777216)." in str(err) try: # one of the size values is zero @@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err) try: # negative value for resize @@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input is not within the required interval of (1 to 16777216)." 
in str(err) try: # invalid input shape diff --git a/tests/ut/python/dataset/test_random_sharpness.py b/tests/ut/python/dataset/test_random_sharpness.py index d8207ff099b..22e5c66f1a1 100644 --- a/tests/ut/python/dataset/test_random_sharpness.py +++ b/tests/ut/python/dataset/test_random_sharpness.py @@ -97,7 +97,7 @@ def test_random_sharpness_md5(): # define map operations transforms = [ F.Decode(), - F.RandomSharpness((0.5, 1.5)), + F.RandomSharpness((0.1, 1.9)), F.ToTensor() ] transform = F.ComposeOp(transforms) diff --git a/tests/ut/python/dataset/test_random_vertical_flip.py b/tests/ut/python/dataset/test_random_vertical_flip.py index 2fc9b127745..a3d02959fdd 100644 --- a/tests/ut/python/dataset/test_random_vertical_flip.py +++ b/tests/ut/python/dataset/test_random_vertical_flip.py @@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c(): data = data.map(input_columns=["image"], operations=random_horizontal_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) def test_random_vertical_invalid_prob_py(): @@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert 'Input prob is not within the required interval of (0.0 to 1.0).' 
in str(e) def test_random_vertical_comp(plot=False): diff --git a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py index be6778b1c63..490dc3e419b 100644 --- a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py @@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err) def test_random_vertical_flip_with_bbox_op_bad_c(): diff --git a/tests/ut/python/dataset/test_resize_with_bbox.py b/tests/ut/python/dataset/test_resize_with_bbox.py index 5fb957aa323..3bb731ee970 100644 --- a/tests/ut/python/dataset/test_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_resize_with_bbox.py @@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c(): # invalid interpolation value c_vision.ResizeWithBBox(400, interpolation="invalid") - except ValueError as err: + except TypeError as err: logger.info("Got an exception in DE: {}".format(str(err))) assert "interpolation" in str(err) diff --git a/tests/ut/python/dataset/test_shuffle.py b/tests/ut/python/dataset/test_shuffle.py index 56cc65a23b2..460c491ca1b 100644 --- a/tests/ut/python/dataset/test_shuffle.py +++ b/tests/ut/python/dataset/test_shuffle.py @@ -154,7 +154,7 @@ def test_shuffle_exception_01(): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) + assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) def test_shuffle_exception_02(): @@ -172,7 +172,7 @@ def test_shuffle_exception_02(): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) + assert "Input buffer_size is not within the 
required interval of (2 to 2147483647)" in str(e) def test_shuffle_exception_03(): @@ -190,7 +190,7 @@ def test_shuffle_exception_03(): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) + assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) def test_shuffle_exception_05(): diff --git a/tests/ut/python/dataset/test_ten_crop.py b/tests/ut/python/dataset/test_ten_crop.py index 7bffea5cc9d..d196bc05cf5 100644 --- a/tests/ut/python/dataset/test_ten_crop.py +++ b/tests/ut/python/dataset/test_ten_crop.py @@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False): logger.info("dtype of image_2: {}".format(image_2.dtype)) if plot: - visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) + visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) # The output data should be of a 4D tensor shape, a stack of 10 images. assert len(image_2.shape) == 4 @@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg(): vision.TenCrop(0), lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images ] - error_msg = "Input is not within the required range" + error_msg = "Input is not within the required interval of (1 to 16777216)." 
assert error_msg == str(info.value) with pytest.raises(ValueError) as info: diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py index a26b6472656..2edd832d79a 100644 --- a/tests/ut/python/dataset/test_uniform_augment.py +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "operations" in str(e) + assert "Argument tensor_op_5 with value" \ + " ,)" in str(e) def test_cpp_uniform_augment_exception_large_numops(num_ops=6): @@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "num_ops" in str(e) + assert "Input num_ops must be greater than 0" in str(e) def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): @@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "integer" in str(e) + assert "Argument num_ops with value 2.5 is not of type (,)" in str(e) def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py index 432c01ef46e..11c57354065 100644 --- a/tests/ut/python/dataset/util.py +++ b/tests/ut/python/dataset/util.py @@ -314,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows= if len(orig) != len(aug) or not orig: return - batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together + batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together split_point = batch_size * plot_rows orig, aug = np.array(orig), np.array(aug) if len(orig) > plot_rows: # Create batches of required size and add remainder to last batch - orig = 
np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added + orig = np.split(orig[:split_point], batch_size) + ( + [orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) else: orig = [orig] @@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows= for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): cur_ix = base_ix + x - (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row + # select plotting axes based on number of image rows on plot - else case when 1 row + (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) axA.imshow(dataA["image"]) add_bounding_boxes(axA, dataA[annot_name])